blob: 23ff70a52e48b06f9e98351eadc7333fffaa0c36 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000012#include <armnn/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000015#include <armnn/backends/IBackendInternal.hpp>
16#include <backendsCommon/CpuTensorHandle.hpp>
17#include <backendsCommon/WorkloadFactory.hpp>
18
Francis Murtagh46c09d02019-05-28 08:15:28 +010019#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
21#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000022#include <boost/iterator/transform_iterator.hpp>
23
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000025#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000026
telsoa014fcda012018-03-09 14:13:49 +000027namespace armnn
28{
29
telsoa01c577f2c2018-08-31 09:22:23 +010030namespace
31{
telsoa01c577f2c2018-08-31 09:22:23 +010032
David Beck29c75de2018-10-23 13:35:58 +010033const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
34{
35 if (!type)
36 {
37 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010038 }
39
David Beck29c75de2018-10-23 13:35:58 +010040 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
David Beck33f0ae02018-10-18 15:13:56 +010045bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010046 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010047 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010048 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010052 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
53
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
71 case LayerType::Activation:
72 {
73 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010075 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010076 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010077 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010080 reason);
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
83 case LayerType::Addition:
84 {
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010088 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010089 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010092 reason);
telsoa014fcda012018-03-09 14:13:49 +000093 break;
94 }
Nikhil Rajee391d52019-09-05 17:50:44 +010095 case LayerType::ArgMinMax:
96 {
97 auto cLayer = boost::polymorphic_downcast<const ArgMinMaxLayer*>(&layer);
98 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
99
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
102 result = layerSupportObject->IsArgMinMaxSupported(
103 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000104 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 descriptor,
106 reason);
107 break;
108 }
telsoa014fcda012018-03-09 14:13:49 +0000109 case LayerType::BatchNormalization:
110 {
111 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100118 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100126 reason);
telsoa014fcda012018-03-09 14:13:49 +0000127 break;
128 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000129 case LayerType::BatchToSpaceNd:
130 {
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
133 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
134
135 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
138 reason);
139 break;
140 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100141 case LayerType::Comparison:
142 {
143 auto cLayer = boost::polymorphic_downcast<const ComparisonLayer*>(&layer);
144
145 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
148
149 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
150 OverrideDataType(input1, dataType),
151 OverrideDataType(output, DataType::Boolean),
152 cLayer->GetParameters(),
153 reason);
154 break;
155 }
telsoa014fcda012018-03-09 14:13:49 +0000156 case LayerType::Constant:
157 {
158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100159 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100160 break;
161 }
162 case LayerType::ConvertFp16ToFp32:
163 {
164 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
165 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100166 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100167 break;
168 }
169 case LayerType::ConvertFp32ToFp16:
170 {
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100173 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000174 break;
175 }
176 case LayerType::Convolution2d:
177 {
178 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100179
180 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
181 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100182 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100183 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
184
arovir01a6824102018-08-28 17:40:45 +0100185 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100186
arovir01a6824102018-08-28 17:40:45 +0100187 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100188 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100189 if (descriptor.m_BiasEnabled)
190 {
David Beck5eec11d2018-10-04 15:43:17 +0100191 biases =
192 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100193 }
194
David Beck33f0ae02018-10-18 15:13:56 +0100195 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100196 input,
197 output,
198 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100199 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100200 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100201 reason);
telsoa014fcda012018-03-09 14:13:49 +0000202 break;
203 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000204 case LayerType::Debug:
205 {
206 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
207 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
208
209 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
210 OverrideDataType(output, dataType),
211 reason);
212 break;
213 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100214 case LayerType::DepthToSpace:
215 {
216 auto cLayer = boost::polymorphic_downcast<const DepthToSpaceLayer*>(&layer);
217
218 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
219 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
220
221 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
222 OverrideDataType(output, dataType),
223 cLayer->GetParameters(),
224 reason);
225 break;
226 }
telsoa014fcda012018-03-09 14:13:49 +0000227 case LayerType::DepthwiseConvolution2d:
228 {
229 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100230 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
231 dataType);
232 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
233 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
234
telsoa01c577f2c2018-08-31 09:22:23 +0100235 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100236
237 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100238 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100239 if (descriptor.m_BiasEnabled)
240 {
David Beck5eec11d2018-10-04 15:43:17 +0100241 biases =
242 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100243 }
telsoa01c577f2c2018-08-31 09:22:23 +0100244
David Beck33f0ae02018-10-18 15:13:56 +0100245 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100246 input,
247 output,
248 descriptor,
249 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100250 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100251 reason);
telsoa014fcda012018-03-09 14:13:49 +0000252 break;
253 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000254 case LayerType::Dequantize:
255 {
256 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
257 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
258
Aron Virginas-Tar87972be2019-11-13 15:16:28 +0000259 result = layerSupportObject->IsDequantizeSupported(input,
260 OverrideDataType(output, dataType),
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000261 reason);
262 break;
263 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000264 case LayerType::DetectionPostProcess:
265 {
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000266 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000267 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
268 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
269 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
270
271 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
272 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
273 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
274 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
275
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000276 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000277 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
278 scores,
279 anchors,
280 detectionBoxes,
281 detectionClasses,
282 detectionScores,
283 numDetections,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000284 descriptor,
285 reason);
286 break;
287 }
josh minor4a3c6102020-01-06 16:40:46 -0600288 case LayerType::ElementwiseUnary:
289 {
290 auto cLayer = boost::polymorphic_downcast<const ElementwiseUnaryLayer*>(&layer);
291
292 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
293 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
294
295 result = layerSupportObject->IsElementwiseUnarySupported(OverrideDataType(input, dataType),
296 OverrideDataType(output, dataType),
297 cLayer->GetParameters(),
298 reason);
299 break;
300 }
telsoa014fcda012018-03-09 14:13:49 +0000301 case LayerType::FakeQuantization:
302 {
303 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
304 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100305 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
306 cLayer->GetParameters(),
307 reason);
telsoa014fcda012018-03-09 14:13:49 +0000308 break;
309 }
310 case LayerType::Floor:
311 {
312 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
313 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100314 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
315 OverrideDataType(output, dataType),
316 reason);
telsoa014fcda012018-03-09 14:13:49 +0000317 break;
318 }
319 case LayerType::FullyConnected:
320 {
321 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
322 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100323 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
324 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
325
326 TensorInfo biasInfo;
327 const TensorInfo * biasInfoPtr = nullptr;
328 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
329 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
330 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
331
332 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
333 if (descriptor.m_BiasEnabled)
334 {
335 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
336 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
337 biasInfoPtr = &biasInfo;
338 }
339 else
340 {
341 // If biases are not enabled pass a dummy tensorinfo for the validation
342 switch(input.GetDataType())
343 {
344 case DataType::Float16:
345 {
346 biasInfoPtr = &dummyFloat16Bias;
347 break;
348 }
349 case DataType::Float32:
350 {
351 biasInfoPtr = &dummyFloat32Bias;
352 break;
353 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000354 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000355 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000356 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000357 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100358 {
359 biasInfoPtr = &dummyQA8Bias;
360 break;
361 }
362 default:
363 {
364 BOOST_ASSERT_MSG(false, "Unexpected bias type");
365 }
366 }
367 }
368
David Beck33f0ae02018-10-18 15:13:56 +0100369 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100370 OverrideDataType(input, dataType),
371 OverrideDataType(output, dataType),
372 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
373 *biasInfoPtr,
374 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100375 reason);
telsoa014fcda012018-03-09 14:13:49 +0000376 break;
377 }
narpra01b89b05f2019-01-16 09:53:09 +0000378 case LayerType::Gather:
379 {
380 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
381 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
382 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
383 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100384 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000385 OverrideDataType(output, dataType),
386 reason);
387 break;
388 }
telsoa014fcda012018-03-09 14:13:49 +0000389 case LayerType::Input:
390 {
391 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100392 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000393 break;
394 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100395 case LayerType::InstanceNormalization:
396 {
397 auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
398 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
399
400 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
401 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
402
403 result = layerSupportObject->IsInstanceNormalizationSupported(
404 OverrideDataType(input, dataType),
405 OverrideDataType(output, dataType),
406 descriptor,
407 reason);
408 break;
409 }
telsoa014fcda012018-03-09 14:13:49 +0000410 case LayerType::L2Normalization:
411 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100412 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
413 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
414
telsoa014fcda012018-03-09 14:13:49 +0000415 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100416 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100417
David Beck33f0ae02018-10-18 15:13:56 +0100418 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100419 OverrideDataType(input, dataType),
420 OverrideDataType(output, dataType),
421 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100422 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100423 break;
424 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100425 case LayerType::LogSoftmax:
426 {
427 auto cLayer = boost::polymorphic_downcast<const LogSoftmaxLayer*>(&layer);
428
429 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
430 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
431
432 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
433 OverrideDataType(output, dataType),
434 cLayer->GetParameters(),
435 reason);
436 break;
437 }
telsoa01c577f2c2018-08-31 09:22:23 +0100438 case LayerType::Lstm:
439 {
440 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
441 const LstmDescriptor& descriptor = cLayer->GetParameters();
442
443 // All inputs.
444 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
445 dataType);
446 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
447 dataType);
448 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
449 dataType);
450 // All outputs
451 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
452 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
453 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
454 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
455
456 // Basic parameters
457 const TensorInfo& inputToForgetWeights
458 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
459 const TensorInfo& inputToCellWeights
460 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
461 const TensorInfo& inputToOutputWeights
462 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
463 const TensorInfo& recurrentToForgetWeights
464 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
465 const TensorInfo& recurrentToCellWeights
466 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
467 const TensorInfo& recurrentToOutputWeights
468 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
469 const TensorInfo& forgetGateBias
470 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
471 const TensorInfo& cellBias
472 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
473 const TensorInfo& outputGateBias
474 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
475
Jan Eilersd01a83c2019-07-03 18:20:40 +0100476 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100477
Jan Eilersd01a83c2019-07-03 18:20:40 +0100478 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
479 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
480 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
481 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
482 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
483 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
484 paramsInfo.m_ForgetGateBias = &forgetGateBias;
485 paramsInfo.m_CellBias = &cellBias;
486 paramsInfo.m_OutputGateBias = &outputGateBias;
487
488
489 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100490 TensorInfo optInputToInputWeights;
491 TensorInfo optRecurrentToInputWeights;
492 TensorInfo optCellToInputWeights;
493 TensorInfo optInputGateBias;
494 TensorInfo optProjectionWeights;
495 TensorInfo optProjectionBias;
496 TensorInfo optCellToForgetWeights;
497 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100498 TensorInfo optInputLayerNormWeights;
499 TensorInfo optForgetLayerNormWeights;
500 TensorInfo optCellLayerNormWeights;
501 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100502
503 if(!descriptor.m_CifgEnabled)
504 {
505 optInputToInputWeights =
506 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100507 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100508
509 optRecurrentToInputWeights =
510 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100511 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100512 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
513 {
514 optCellToInputWeights =
515 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100516 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100517 }
518 optInputGateBias =
519 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100520 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100521 }
522
523 if(descriptor.m_ProjectionEnabled)
524 {
525 optProjectionWeights =
526 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100527 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100528 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
529 {
530 optProjectionBias =
531 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100532 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100533 }
534 }
535
536 if(descriptor.m_PeepholeEnabled)
537 {
538 optCellToForgetWeights =
539 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100540 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100541 optCellToOutputWeights =
542 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100543 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100544 }
545
Jan Eilers38e05bd2019-06-26 13:10:09 +0100546 if(descriptor.m_LayerNormEnabled)
547 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100548 if (!descriptor.m_CifgEnabled)
549 {
550 optInputLayerNormWeights = OverrideDataType(
551 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
552 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
553 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100554
555 optForgetLayerNormWeights = OverrideDataType(
556 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100557 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100558
559 optCellLayerNormWeights = OverrideDataType(
560 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100561 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100562
563 optOutputLayerNormWeights = OverrideDataType(
564 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100565 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100566 }
567
David Beck33f0ae02018-10-18 15:13:56 +0100568 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100569 input,
570 outputStateIn,
571 cellStateIn,
572 scratchBuffer,
573 outputStateOut,
574 cellStateOut,
575 output,
576 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100577 paramsInfo,
578 reason);
telsoa014fcda012018-03-09 14:13:49 +0000579 break;
580 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000581 case LayerType::Maximum:
582 {
583 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
584 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
585 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
586
587 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
588 OverrideDataType(input1, dataType),
589 OverrideDataType(output, dataType),
590 reason);
591 break;
592 }
narpra01b89b05f2019-01-16 09:53:09 +0000593 case LayerType::MemCopy:
594 {
595 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
596 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000597
narpra01b89b05f2019-01-16 09:53:09 +0000598 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
599 OverrideDataType(output, dataType),
600 reason);
601 break;
602 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100603 case LayerType::MemImport:
604 {
605 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
606 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
607
608 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
609 OverrideDataType(output, dataType),
610 reason);
611 break;
612 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100613 case LayerType::Merge:
614 {
615 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
616 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
617 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
618
619 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
620 OverrideDataType(input1, dataType),
621 OverrideDataType(output, dataType),
622 reason);
623 break;
624 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100625 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000626 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100627 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000628
telsoa01c577f2c2018-08-31 09:22:23 +0100629 // Get vector of all inputs.
630 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000631 {
telsoa01c577f2c2018-08-31 09:22:23 +0100632 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000633 };
telsoa01c577f2c2018-08-31 09:22:23 +0100634 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
635 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
636 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000637
telsoa01c577f2c2018-08-31 09:22:23 +0100638 auto getTensorInfoPtr = [](const TensorInfo& info)
639 {
640 return &info;
641 };
642 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
643 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
644 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000645
Nikhil Raj8599a412018-11-19 14:51:07 +0000646 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
647
Jim Flynne242f2d2019-05-22 14:24:13 +0100648 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
649
650
telsoa014fcda012018-03-09 14:13:49 +0000651 break;
652 }
653 case LayerType::Multiplication:
654 {
655 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
656 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100657 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100658 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100659 OverrideDataType(input0, dataType),
660 OverrideDataType(input1, dataType),
661 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100662 reason);
telsoa014fcda012018-03-09 14:13:49 +0000663 break;
664 }
665 case LayerType::Normalization:
666 {
667 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
668 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
669 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100670 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
671 OverrideDataType(output, dataType),
672 cLayer->GetParameters(),
673 reason);
telsoa014fcda012018-03-09 14:13:49 +0000674 break;
675 }
676 case LayerType::Output:
677 {
678 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100679 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000680 break;
681 }
682 case LayerType::Permute:
683 {
684 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
685 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
686 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100687 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
688 OverrideDataType(output, dataType),
689 cLayer->GetParameters(),
690 reason);
telsoa014fcda012018-03-09 14:13:49 +0000691 break;
692 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100693 case LayerType::Pad:
694 {
695 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
696 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
697 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100698 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100699 OverrideDataType(input, dataType),
700 OverrideDataType(output, dataType),
701 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100702 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100703 break;
704 }
telsoa014fcda012018-03-09 14:13:49 +0000705 case LayerType::Pooling2d:
706 {
707 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
708 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
709 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100710 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
711 OverrideDataType(output, dataType),
712 cLayer->GetParameters(),
713 reason);
telsoa014fcda012018-03-09 14:13:49 +0000714 break;
715 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000716 case LayerType::PreCompiled:
717 {
718 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
719 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
720 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
721 cLayer->GetParameters(),
722 reason);
723 break;
724 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000725 case LayerType::Quantize:
726 {
727 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
728 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
729 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
730 break;
731 }
James Conroyee18dc82019-07-17 11:27:46 +0100732 case LayerType::QuantizedLstm:
733 {
734 auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
735
736 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100737 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
738 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
739 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100740
741 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100742 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
743 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100744
745 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100746 QuantizedLstmInputParamsInfo paramsInfo;
747
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100748 paramsInfo.m_InputToInputWeights =
749 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
750 paramsInfo.m_InputToForgetWeights =
751 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
752 paramsInfo.m_InputToCellWeights =
753 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
754 paramsInfo.m_InputToOutputWeights =
755 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100756
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100757 paramsInfo.m_RecurrentToInputWeights =
758 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
759 paramsInfo.m_RecurrentToForgetWeights =
760 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
761 paramsInfo.m_RecurrentToCellWeights =
762 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
763 paramsInfo.m_RecurrentToOutputWeights =
764 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100765
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100766 paramsInfo.m_InputGateBias =
767 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
768 paramsInfo.m_ForgetGateBias =
769 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
770 paramsInfo.m_CellBias =
771 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
772 paramsInfo.m_OutputGateBias =
773 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100774
775 result = layerSupportObject->IsQuantizedLstmSupported(input,
776 previousCellStateIn,
777 previousOutputIn,
778 cellStateOut,
779 output,
780 paramsInfo,
781 reason);
782 break;
783 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100784 case LayerType::Division:
785 {
786 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
787 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
788 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100789 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100790 OverrideDataType(input0, dataType),
791 OverrideDataType(input1, dataType),
792 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100793 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100794 break;
795 }
telsoa014fcda012018-03-09 14:13:49 +0000796 case LayerType::Reshape:
797 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000798 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000799 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000800 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000801 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
Kevin Maya023c402019-12-12 17:28:05 +0000802 OverrideDataType(output, dataType),
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000803 cLayer->GetParameters(),
804 reason);
telsoa014fcda012018-03-09 14:13:49 +0000805 break;
806 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100807 case LayerType::Resize:
808 {
809 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100810 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100811 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
812 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
813 OverrideDataType(output, dataType),
814 cLayer->GetParameters(),
815 reason);
816 break;
817 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100818 case LayerType::Slice:
819 {
820 auto cLayer = boost::polymorphic_downcast<const SliceLayer*>(&layer);
821
822 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
823 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
824
825 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
826 OverrideDataType(output, dataType),
827 cLayer->GetParameters(),
828 reason);
829 break;
830 }
telsoa014fcda012018-03-09 14:13:49 +0000831 case LayerType::Softmax:
832 {
833 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
834 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100835 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100836 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
837 OverrideDataType(output, dataType),
838 cLayer->GetParameters(),
839 reason);
telsoa014fcda012018-03-09 14:13:49 +0000840 break;
841 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000842 case LayerType::SpaceToBatchNd:
843 {
844 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
845 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
846 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
847 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
848 OverrideDataType(output, dataType),
849 cLayer->GetParameters(),
850 reason);
851 break;
852 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100853 case LayerType::SpaceToDepth:
854 {
855 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
856
857 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
858 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
859
860 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
861 OverrideDataType(output, dataType),
862 cLayer->GetParameters(),
863 reason);
864 break;
865 }
telsoa014fcda012018-03-09 14:13:49 +0000866 case LayerType::Splitter:
867 {
868 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
869 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100870
871 // Get vector of all outputs.
872 auto getTensorInfo = [&dataType](const OutputSlot& slot)
873 {
874 return OverrideDataType(slot.GetTensorInfo(), dataType);
875 };
876 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
877 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
878 std::vector<TensorInfo> outputs(beginI, endI);
879
880 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
881
David Beck33f0ae02018-10-18 15:13:56 +0100882 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100883 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100884 cLayer->GetParameters(),
885 reason);
telsoa014fcda012018-03-09 14:13:49 +0000886 break;
887 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100888 case LayerType::Stack:
889 {
890 auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);
891
892 // Get vector of all inputs.
893 auto getTensorInfo = [&dataType](const InputSlot& slot)
894 {
895 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
896 };
897 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
898 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
899 std::vector<TensorInfo> inputs(beginI, endI);
900
901 auto getTensorInfoPtr = [](const TensorInfo& info)
902 {
903 return &info;
904 };
905 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
906 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
907 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
908
909 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
910
911 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
912
913 break;
914 }
Derek Lamberti013c3902019-10-21 10:46:16 +0100915 case LayerType::StandIn:
916 {
917 auto cLayer = boost::polymorphic_downcast<const StandInLayer*>(&layer);
918
919 // Get vector of all inputs.
920 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
921 {
922 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
923 };
924 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
925 {
926 return OverrideDataType(slot.GetTensorInfo(), dataType);
927 };
928 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
929 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
930 std::vector<TensorInfo> inputs(beginI, endI);
931
932 auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
933 auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
934 std::vector<TensorInfo> outputs(beginO, endO);
935
936
937 auto getTensorInfoPtr = [](const TensorInfo& info)
938 {
939 return &info;
940 };
941 auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
942 auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
943 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
944
945 auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
946 auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
947 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
948
949
950 result = layerSupportObject->IsStandInSupported(inputPtrs,
951 outputPtrs,
952 cLayer->GetParameters(),
953 reason);
954 break;
955 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000956 case LayerType::StridedSlice:
957 {
958 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
959 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
960 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
961 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
962 OverrideDataType(output, dataType),
963 cLayer->GetParameters(),
964 reason);
965 break;
966 }
David Beckc2044fe2018-09-05 15:00:38 +0100967 case LayerType::Subtraction:
968 {
969 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
970 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
971 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100972 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100973 OverrideDataType(input0, dataType),
974 OverrideDataType(input1, dataType),
975 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100976 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100977 break;
978 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100979 case LayerType::Switch:
980 {
981 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
982 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
983 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
984 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
985 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
986 OverrideDataType(input1, dataType),
987 OverrideDataType(output0, dataType),
988 OverrideDataType(output1, dataType),
989 reason);
990 break;
991 }
narpra0132b90462018-09-13 11:07:48 +0100992 case LayerType::Mean:
993 {
994 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
995 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
996 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100997 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100998 OverrideDataType(input, dataType),
999 OverrideDataType(output, dataType),
1000 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001001 reason);
narpra0132b90462018-09-13 11:07:48 +01001002 break;
1003 }
kevmay0190539692018-11-29 08:40:19 +00001004 case LayerType::Minimum:
1005 {
1006 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1007 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1008 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1009 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1010 OverrideDataType(input1, dataType),
1011 OverrideDataType(output, dataType),
1012 reason);
1013 break;
1014 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001015 case LayerType::Prelu:
1016 {
1017 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1018 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1019 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1020 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1021 OverrideDataType(alpha, dataType),
1022 OverrideDataType(output, dataType),
1023 reason);
1024 break;
1025 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001026 case LayerType::TransposeConvolution2d:
1027 {
1028 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
1029
1030 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1031 dataType);
1032 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1033
1034 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1035
1036 Optional<TensorInfo> biases;
1037 if (descriptor.m_BiasEnabled)
1038 {
1039 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
1040 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1041 GetBiasTypeFromWeightsType(dataType));
1042 }
1043
1044 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
1045 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1046
1047 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1048 output,
1049 descriptor,
1050 weights,
1051 biases,
1052 reason);
1053
1054 break;
1055 }
telsoa014fcda012018-03-09 14:13:49 +00001056 default:
1057 {
1058 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001059 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001060 result = false;
1061 break;
1062 }
1063 }
telsoa014fcda012018-03-09 14:13:49 +00001064 return result;
1065}
1066
David Beckdcb751f2018-10-03 11:42:42 +01001067bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001068 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001069 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001070{
David Beckdcb751f2018-10-03 11:42:42 +01001071 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +01001072 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +00001073}
1074
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001075// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001076std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1077 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001078{
1079 return std::unique_ptr<IWorkload>();
1080}
1081
Derek Lamberti901ea112019-12-10 22:07:09 +00001082std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1083 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001084{
1085 return std::unique_ptr<IWorkload>();
1086}
1087
Derek Lamberti901ea112019-12-10 22:07:09 +00001088std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1089 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001090{
1091 return std::unique_ptr<IWorkload>();
1092}
1093
Derek Lamberti901ea112019-12-10 22:07:09 +00001094std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1095 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001096{
1097 return std::unique_ptr<IWorkload>();
1098}
1099
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001100std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001101 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001102{
1103 return std::unique_ptr<IWorkload>();
1104}
1105
Derek Lamberti901ea112019-12-10 22:07:09 +00001106std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1107 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001108{
1109 return std::unique_ptr<IWorkload>();
1110}
1111
Derek Lamberti901ea112019-12-10 22:07:09 +00001112std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1113 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001114{
1115 return std::unique_ptr<IWorkload>();
1116}
1117
Derek Lamberti901ea112019-12-10 22:07:09 +00001118std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1119 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001120{
1121 return std::unique_ptr<IWorkload>();
1122}
1123
Derek Lamberti901ea112019-12-10 22:07:09 +00001124std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1125 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001126{
1127 return std::unique_ptr<IWorkload>();
1128}
1129
Derek Lamberti901ea112019-12-10 22:07:09 +00001130std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1131 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001132{
1133 return std::unique_ptr<IWorkload>();
1134}
1135
Derek Lamberti901ea112019-12-10 22:07:09 +00001136std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1137 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001138{
1139 return std::unique_ptr<IWorkload>();
1140}
1141
Derek Lamberti901ea112019-12-10 22:07:09 +00001142std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1143 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001144{
1145 return std::unique_ptr<IWorkload>();
1146}
1147
Derek Lamberti901ea112019-12-10 22:07:09 +00001148std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1149 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001150{
1151 return std::unique_ptr<IWorkload>();
1152}
1153
Derek Lamberti901ea112019-12-10 22:07:09 +00001154std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1155 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001156{
1157 return std::unique_ptr<IWorkload>();
1158}
1159
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001160std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001161 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001162{
1163 return std::unique_ptr<IWorkload>();
1164}
1165
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001166std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001167 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001168{
1169 return std::unique_ptr<IWorkload>();
1170}
1171
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001172std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001173 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001174{
1175 return std::unique_ptr<IWorkload>();
1176}
1177
Derek Lamberti901ea112019-12-10 22:07:09 +00001178std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1179 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001180{
1181 return std::unique_ptr<IWorkload>();
1182}
1183
josh minor4a3c6102020-01-06 16:40:46 -06001184std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1185 const WorkloadInfo& /*info*/) const
1186{
1187 return std::unique_ptr<IWorkload>();
1188}
1189
Derek Lamberti901ea112019-12-10 22:07:09 +00001190std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1191 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001192{
1193 return std::unique_ptr<IWorkload>();
1194}
1195
Derek Lamberti901ea112019-12-10 22:07:09 +00001196std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1197 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001198{
1199 return std::unique_ptr<IWorkload>();
1200}
1201
Derek Lamberti901ea112019-12-10 22:07:09 +00001202std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1203 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001204{
1205 return std::unique_ptr<IWorkload>();
1206}
1207
Derek Lamberti901ea112019-12-10 22:07:09 +00001208std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1209 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001210{
1211 return std::unique_ptr<IWorkload>();
1212}
1213
Derek Lamberti901ea112019-12-10 22:07:09 +00001214std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1215 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001216{
1217 return std::unique_ptr<IWorkload>();
1218}
1219
Derek Lamberti901ea112019-12-10 22:07:09 +00001220std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1221 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001222{
1223 return std::unique_ptr<IWorkload>();
1224}
1225
Kevin Mayce5045a2019-10-02 14:07:47 +01001226std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001227 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1228 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001229{
1230 return std::unique_ptr<IWorkload>();
1231}
1232
Derek Lamberti901ea112019-12-10 22:07:09 +00001233std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1234 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001235{
1236 return std::unique_ptr<IWorkload>();
1237}
1238
Derek Lamberti901ea112019-12-10 22:07:09 +00001239std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1240 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001241{
1242 return std::unique_ptr<IWorkload>();
1243}
1244
Derek Lamberti901ea112019-12-10 22:07:09 +00001245std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1246 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001247{
1248 return std::unique_ptr<IWorkload>();
1249}
1250
Derek Lamberti901ea112019-12-10 22:07:09 +00001251std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1252 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001253{
1254 return std::unique_ptr<IWorkload>();
1255}
1256
Derek Lamberti901ea112019-12-10 22:07:09 +00001257std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1258 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001259{
1260 return std::unique_ptr<IWorkload>();
1261}
1262
Derek Lamberti901ea112019-12-10 22:07:09 +00001263std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1264 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001265{
1266 return std::unique_ptr<IWorkload>();
1267}
1268
Derek Lamberti901ea112019-12-10 22:07:09 +00001269std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1270 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001271{
1272 return std::unique_ptr<IWorkload>();
1273}
1274
Derek Lamberti901ea112019-12-10 22:07:09 +00001275std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1276 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001277{
1278 return std::unique_ptr<IWorkload>();
1279}
1280
Derek Lamberti901ea112019-12-10 22:07:09 +00001281std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1282 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001283{
1284 return std::unique_ptr<IWorkload>();
1285}
1286
Derek Lamberti901ea112019-12-10 22:07:09 +00001287std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1288 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001289{
1290 return std::unique_ptr<IWorkload>();
1291}
1292
Derek Lamberti901ea112019-12-10 22:07:09 +00001293std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1294 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001295{
1296 return std::unique_ptr<IWorkload>();
1297}
1298
Derek Lamberti901ea112019-12-10 22:07:09 +00001299std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1300 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001301{
1302 return std::unique_ptr<IWorkload>();
1303}
1304
Derek Lamberti901ea112019-12-10 22:07:09 +00001305std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1306 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001307{
1308 return std::unique_ptr<IWorkload>();
1309}
1310
Derek Lamberti901ea112019-12-10 22:07:09 +00001311std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1312 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001313{
1314 return std::unique_ptr<IWorkload>();
1315}
1316
Derek Lamberti901ea112019-12-10 22:07:09 +00001317std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
1318 const WorkloadInfo&/**/ /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001319{
1320 return std::unique_ptr<IWorkload>();
1321}
1322
Derek Lamberti901ea112019-12-10 22:07:09 +00001323std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1324 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001325{
1326 return std::unique_ptr<IWorkload>();
1327}
1328
Derek Lamberti901ea112019-12-10 22:07:09 +00001329std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1330 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001331{
1332 return std::unique_ptr<IWorkload>();
1333}
1334
Derek Lamberti901ea112019-12-10 22:07:09 +00001335std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1336 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001337{
1338 return std::unique_ptr<IWorkload>();
1339}
1340
Derek Lamberti901ea112019-12-10 22:07:09 +00001341std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1342 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001343{
1344 return std::unique_ptr<IWorkload>();
1345}
1346
Derek Lamberti901ea112019-12-10 22:07:09 +00001347std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1348 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001349{
1350 return std::unique_ptr<IWorkload>();
1351}
1352
Derek Lamberti901ea112019-12-10 22:07:09 +00001353std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1354 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001355{
1356 return std::unique_ptr<IWorkload>();
1357}
1358
Derek Lamberti901ea112019-12-10 22:07:09 +00001359std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1360 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001361{
1362 return std::unique_ptr<IWorkload>();
1363}
1364
Derek Lamberti901ea112019-12-10 22:07:09 +00001365std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1366 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001367{
1368 return std::unique_ptr<IWorkload>();
1369}
1370
Derek Lamberti901ea112019-12-10 22:07:09 +00001371std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1372 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001373{
1374 return std::unique_ptr<IWorkload>();
1375}
1376
Derek Lamberti901ea112019-12-10 22:07:09 +00001377std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1378 const WorkloadInfo& /*info*/) const
1379{
1380 return std::unique_ptr<IWorkload>();
1381}
1382/**/
1383std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1384 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001385{
1386 return std::unique_ptr<IWorkload>();
1387}
1388
Derek Lamberti901ea112019-12-10 22:07:09 +00001389std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1390 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001391{
1392 return std::unique_ptr<IWorkload>();
1393}
1394
Derek Lamberti901ea112019-12-10 22:07:09 +00001395std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1396 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001397{
1398 return std::unique_ptr<IWorkload>();
1399}
1400
Derek Lamberti901ea112019-12-10 22:07:09 +00001401std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1402 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001403{
1404 return std::unique_ptr<IWorkload>();
1405}
1406
Derek Lamberti901ea112019-12-10 22:07:09 +00001407std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1408 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001409{
1410 return std::unique_ptr<IWorkload>();
1411}
1412
Derek Lamberti901ea112019-12-10 22:07:09 +00001413std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1414 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001415{
1416 return std::unique_ptr<IWorkload>();
1417}
1418
Derek Lamberti901ea112019-12-10 22:07:09 +00001419std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1420 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001421{
1422 return std::unique_ptr<IWorkload>();
1423}
1424
Derek Lamberti901ea112019-12-10 22:07:09 +00001425std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1426 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001427{
1428 return std::unique_ptr<IWorkload>();
1429}
1430
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001431std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001432 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1433 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001434{
1435 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001436}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001437
} // namespace armnn