blob: 805ec7ba5ff0205b445793bbaf8452ea9bcaf47d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000012#include <armnn/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000015#include <armnn/backends/IBackendInternal.hpp>
16#include <backendsCommon/CpuTensorHandle.hpp>
17#include <backendsCommon/WorkloadFactory.hpp>
18
Francis Murtagh46c09d02019-05-28 08:15:28 +010019#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
21#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000022#include <boost/iterator/transform_iterator.hpp>
23
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000025#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000026
telsoa014fcda012018-03-09 14:13:49 +000027namespace armnn
28{
29
telsoa01c577f2c2018-08-31 09:22:23 +010030namespace
31{
telsoa01c577f2c2018-08-31 09:22:23 +010032
David Beck29c75de2018-10-23 13:35:58 +010033const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
34{
35 if (!type)
36 {
37 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010038 }
39
David Beck29c75de2018-10-23 13:35:58 +010040 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
David Beck33f0ae02018-10-18 15:13:56 +010045bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010046 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010047 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010048 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010052 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
53
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
Kevin May868eb142019-09-04 17:29:31 +010071 case LayerType::Abs:
72 {
73 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
74 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
75 result = layerSupportObject->IsAbsSupported(OverrideDataType(input, dataType),
76 OverrideDataType(output, dataType),
77 reason);
78 break;
79 }
telsoa014fcda012018-03-09 14:13:49 +000080 case LayerType::Activation:
81 {
82 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
83 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010084 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010085 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010086 OverrideDataType(input, dataType),
87 OverrideDataType(output, dataType),
88 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010089 reason);
telsoa014fcda012018-03-09 14:13:49 +000090 break;
91 }
92 case LayerType::Addition:
93 {
94 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
95 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
96 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010097 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010098 OverrideDataType(input0, dataType),
99 OverrideDataType(input1, dataType),
100 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100101 reason);
telsoa014fcda012018-03-09 14:13:49 +0000102 break;
103 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100104 case LayerType::ArgMinMax:
105 {
106 auto cLayer = boost::polymorphic_downcast<const ArgMinMaxLayer*>(&layer);
107 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
108
109 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
110 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
111 result = layerSupportObject->IsArgMinMaxSupported(
112 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000113 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100114 descriptor,
115 reason);
116 break;
117 }
telsoa014fcda012018-03-09 14:13:49 +0000118 case LayerType::BatchNormalization:
119 {
120 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
121 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100122 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
123 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
124 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
125 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
126 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100127 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100128 OverrideDataType(input, dataType),
129 OverrideDataType(output, dataType),
130 OverrideDataType(mean, dataType),
131 OverrideDataType(var, dataType),
132 OverrideDataType(beta, dataType),
133 OverrideDataType(gamma, dataType),
134 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100135 reason);
telsoa014fcda012018-03-09 14:13:49 +0000136 break;
137 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000138 case LayerType::BatchToSpaceNd:
139 {
140 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
141 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
142 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
143
144 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
145 OverrideDataType(output, dataType),
146 cLayer->GetParameters(),
147 reason);
148 break;
149 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100150 case LayerType::Comparison:
151 {
152 auto cLayer = boost::polymorphic_downcast<const ComparisonLayer*>(&layer);
153
154 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
155 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
156 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
157
158 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
159 OverrideDataType(input1, dataType),
160 OverrideDataType(output, DataType::Boolean),
161 cLayer->GetParameters(),
162 reason);
163 break;
164 }
telsoa014fcda012018-03-09 14:13:49 +0000165 case LayerType::Constant:
166 {
167 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100168 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100169 break;
170 }
171 case LayerType::ConvertFp16ToFp32:
172 {
173 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
174 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100175 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100176 break;
177 }
178 case LayerType::ConvertFp32ToFp16:
179 {
180 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
181 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100182 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000183 break;
184 }
185 case LayerType::Convolution2d:
186 {
187 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100188
189 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
190 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100191 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100192 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
193
arovir01a6824102018-08-28 17:40:45 +0100194 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100195
arovir01a6824102018-08-28 17:40:45 +0100196 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100197 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100198 if (descriptor.m_BiasEnabled)
199 {
David Beck5eec11d2018-10-04 15:43:17 +0100200 biases =
201 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100202 }
203
David Beck33f0ae02018-10-18 15:13:56 +0100204 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100205 input,
206 output,
207 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100208 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100209 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100210 reason);
telsoa014fcda012018-03-09 14:13:49 +0000211 break;
212 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000213 case LayerType::Debug:
214 {
215 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
216 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
217
218 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
219 OverrideDataType(output, dataType),
220 reason);
221 break;
222 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100223 case LayerType::DepthToSpace:
224 {
225 auto cLayer = boost::polymorphic_downcast<const DepthToSpaceLayer*>(&layer);
226
227 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
228 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
229
230 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
231 OverrideDataType(output, dataType),
232 cLayer->GetParameters(),
233 reason);
234 break;
235 }
telsoa014fcda012018-03-09 14:13:49 +0000236 case LayerType::DepthwiseConvolution2d:
237 {
238 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100239 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
240 dataType);
241 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
242 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
243
telsoa01c577f2c2018-08-31 09:22:23 +0100244 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100245
246 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100247 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100248 if (descriptor.m_BiasEnabled)
249 {
David Beck5eec11d2018-10-04 15:43:17 +0100250 biases =
251 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100252 }
telsoa01c577f2c2018-08-31 09:22:23 +0100253
David Beck33f0ae02018-10-18 15:13:56 +0100254 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100255 input,
256 output,
257 descriptor,
258 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100259 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100260 reason);
telsoa014fcda012018-03-09 14:13:49 +0000261 break;
262 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000263 case LayerType::Dequantize:
264 {
265 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
266 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
267
Aron Virginas-Tar87972be2019-11-13 15:16:28 +0000268 result = layerSupportObject->IsDequantizeSupported(input,
269 OverrideDataType(output, dataType),
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000270 reason);
271 break;
272 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000273 case LayerType::DetectionPostProcess:
274 {
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000275 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000276 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
277 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
278 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
279
280 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
281 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
282 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
283 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
284
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000285 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000286 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
287 scores,
288 anchors,
289 detectionBoxes,
290 detectionClasses,
291 detectionScores,
292 numDetections,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000293 descriptor,
294 reason);
295 break;
296 }
telsoa014fcda012018-03-09 14:13:49 +0000297 case LayerType::FakeQuantization:
298 {
299 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
300 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100301 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
302 cLayer->GetParameters(),
303 reason);
telsoa014fcda012018-03-09 14:13:49 +0000304 break;
305 }
306 case LayerType::Floor:
307 {
308 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
309 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100310 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
311 OverrideDataType(output, dataType),
312 reason);
telsoa014fcda012018-03-09 14:13:49 +0000313 break;
314 }
315 case LayerType::FullyConnected:
316 {
317 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
318 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100319 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
320 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
321
322 TensorInfo biasInfo;
323 const TensorInfo * biasInfoPtr = nullptr;
324 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
325 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
326 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
327
328 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
329 if (descriptor.m_BiasEnabled)
330 {
331 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
332 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
333 biasInfoPtr = &biasInfo;
334 }
335 else
336 {
337 // If biases are not enabled pass a dummy tensorinfo for the validation
338 switch(input.GetDataType())
339 {
340 case DataType::Float16:
341 {
342 biasInfoPtr = &dummyFloat16Bias;
343 break;
344 }
345 case DataType::Float32:
346 {
347 biasInfoPtr = &dummyFloat32Bias;
348 break;
349 }
350 case DataType::QuantisedAsymm8:
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100351 case DataType::QuantisedSymm16:
telsoa01c577f2c2018-08-31 09:22:23 +0100352 {
353 biasInfoPtr = &dummyQA8Bias;
354 break;
355 }
356 default:
357 {
358 BOOST_ASSERT_MSG(false, "Unexpected bias type");
359 }
360 }
361 }
362
David Beck33f0ae02018-10-18 15:13:56 +0100363 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100364 OverrideDataType(input, dataType),
365 OverrideDataType(output, dataType),
366 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
367 *biasInfoPtr,
368 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100369 reason);
telsoa014fcda012018-03-09 14:13:49 +0000370 break;
371 }
narpra01b89b05f2019-01-16 09:53:09 +0000372 case LayerType::Gather:
373 {
374 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
375 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
376 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
377 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100378 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000379 OverrideDataType(output, dataType),
380 reason);
381 break;
382 }
telsoa014fcda012018-03-09 14:13:49 +0000383 case LayerType::Input:
384 {
385 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100386 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000387 break;
388 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100389 case LayerType::InstanceNormalization:
390 {
391 auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
392 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
393
394 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
395 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
396
397 result = layerSupportObject->IsInstanceNormalizationSupported(
398 OverrideDataType(input, dataType),
399 OverrideDataType(output, dataType),
400 descriptor,
401 reason);
402 break;
403 }
telsoa014fcda012018-03-09 14:13:49 +0000404 case LayerType::L2Normalization:
405 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100406 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
407 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
408
telsoa014fcda012018-03-09 14:13:49 +0000409 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100410 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100411
David Beck33f0ae02018-10-18 15:13:56 +0100412 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100413 OverrideDataType(input, dataType),
414 OverrideDataType(output, dataType),
415 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100416 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100417 break;
418 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100419 case LayerType::LogSoftmax:
420 {
421 auto cLayer = boost::polymorphic_downcast<const LogSoftmaxLayer*>(&layer);
422
423 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
424 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
425
426 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
427 OverrideDataType(output, dataType),
428 cLayer->GetParameters(),
429 reason);
430 break;
431 }
telsoa01c577f2c2018-08-31 09:22:23 +0100432 case LayerType::Lstm:
433 {
434 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
435 const LstmDescriptor& descriptor = cLayer->GetParameters();
436
437 // All inputs.
438 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
439 dataType);
440 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
441 dataType);
442 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
443 dataType);
444 // All outputs
445 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
446 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
447 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
448 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
449
450 // Basic parameters
451 const TensorInfo& inputToForgetWeights
452 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
453 const TensorInfo& inputToCellWeights
454 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
455 const TensorInfo& inputToOutputWeights
456 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
457 const TensorInfo& recurrentToForgetWeights
458 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
459 const TensorInfo& recurrentToCellWeights
460 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
461 const TensorInfo& recurrentToOutputWeights
462 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
463 const TensorInfo& forgetGateBias
464 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
465 const TensorInfo& cellBias
466 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
467 const TensorInfo& outputGateBias
468 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
469
Jan Eilersd01a83c2019-07-03 18:20:40 +0100470 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100471
Jan Eilersd01a83c2019-07-03 18:20:40 +0100472 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
473 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
474 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
475 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
476 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
477 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
478 paramsInfo.m_ForgetGateBias = &forgetGateBias;
479 paramsInfo.m_CellBias = &cellBias;
480 paramsInfo.m_OutputGateBias = &outputGateBias;
481
482
483 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100484 TensorInfo optInputToInputWeights;
485 TensorInfo optRecurrentToInputWeights;
486 TensorInfo optCellToInputWeights;
487 TensorInfo optInputGateBias;
488 TensorInfo optProjectionWeights;
489 TensorInfo optProjectionBias;
490 TensorInfo optCellToForgetWeights;
491 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100492 TensorInfo optInputLayerNormWeights;
493 TensorInfo optForgetLayerNormWeights;
494 TensorInfo optCellLayerNormWeights;
495 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100496
497 if(!descriptor.m_CifgEnabled)
498 {
499 optInputToInputWeights =
500 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100501 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100502
503 optRecurrentToInputWeights =
504 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100505 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100506 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
507 {
508 optCellToInputWeights =
509 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100510 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100511 }
512 optInputGateBias =
513 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100514 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100515 }
516
517 if(descriptor.m_ProjectionEnabled)
518 {
519 optProjectionWeights =
520 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100521 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100522 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
523 {
524 optProjectionBias =
525 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100526 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100527 }
528 }
529
530 if(descriptor.m_PeepholeEnabled)
531 {
532 optCellToForgetWeights =
533 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100534 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100535 optCellToOutputWeights =
536 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100537 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100538 }
539
Jan Eilers38e05bd2019-06-26 13:10:09 +0100540 if(descriptor.m_LayerNormEnabled)
541 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100542 if (!descriptor.m_CifgEnabled)
543 {
544 optInputLayerNormWeights = OverrideDataType(
545 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
546 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
547 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100548
549 optForgetLayerNormWeights = OverrideDataType(
550 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100551 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100552
553 optCellLayerNormWeights = OverrideDataType(
554 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100555 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100556
557 optOutputLayerNormWeights = OverrideDataType(
558 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100559 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100560 }
561
David Beck33f0ae02018-10-18 15:13:56 +0100562 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100563 input,
564 outputStateIn,
565 cellStateIn,
566 scratchBuffer,
567 outputStateOut,
568 cellStateOut,
569 output,
570 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100571 paramsInfo,
572 reason);
telsoa014fcda012018-03-09 14:13:49 +0000573 break;
574 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000575 case LayerType::Maximum:
576 {
577 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
578 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
579 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
580
581 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
582 OverrideDataType(input1, dataType),
583 OverrideDataType(output, dataType),
584 reason);
585 break;
586 }
narpra01b89b05f2019-01-16 09:53:09 +0000587 case LayerType::MemCopy:
588 {
589 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
590 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000591
narpra01b89b05f2019-01-16 09:53:09 +0000592 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
593 OverrideDataType(output, dataType),
594 reason);
595 break;
596 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100597 case LayerType::MemImport:
598 {
599 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
600 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
601
602 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
603 OverrideDataType(output, dataType),
604 reason);
605 break;
606 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100607 case LayerType::Merge:
608 {
609 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
610 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
611 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
612
613 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
614 OverrideDataType(input1, dataType),
615 OverrideDataType(output, dataType),
616 reason);
617 break;
618 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100619 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000620 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100621 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000622
telsoa01c577f2c2018-08-31 09:22:23 +0100623 // Get vector of all inputs.
624 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000625 {
telsoa01c577f2c2018-08-31 09:22:23 +0100626 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000627 };
telsoa01c577f2c2018-08-31 09:22:23 +0100628 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
629 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
630 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000631
telsoa01c577f2c2018-08-31 09:22:23 +0100632 auto getTensorInfoPtr = [](const TensorInfo& info)
633 {
634 return &info;
635 };
636 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
637 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
638 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000639
Nikhil Raj8599a412018-11-19 14:51:07 +0000640 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
641
Jim Flynne242f2d2019-05-22 14:24:13 +0100642 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
643
644
telsoa014fcda012018-03-09 14:13:49 +0000645 break;
646 }
647 case LayerType::Multiplication:
648 {
649 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
650 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100651 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100652 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100653 OverrideDataType(input0, dataType),
654 OverrideDataType(input1, dataType),
655 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100656 reason);
telsoa014fcda012018-03-09 14:13:49 +0000657 break;
658 }
659 case LayerType::Normalization:
660 {
661 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
662 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
663 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100664 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
665 OverrideDataType(output, dataType),
666 cLayer->GetParameters(),
667 reason);
telsoa014fcda012018-03-09 14:13:49 +0000668 break;
669 }
670 case LayerType::Output:
671 {
672 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100673 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000674 break;
675 }
676 case LayerType::Permute:
677 {
678 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
679 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
680 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100681 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
682 OverrideDataType(output, dataType),
683 cLayer->GetParameters(),
684 reason);
telsoa014fcda012018-03-09 14:13:49 +0000685 break;
686 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100687 case LayerType::Pad:
688 {
689 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
690 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
691 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100692 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100693 OverrideDataType(input, dataType),
694 OverrideDataType(output, dataType),
695 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100696 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100697 break;
698 }
telsoa014fcda012018-03-09 14:13:49 +0000699 case LayerType::Pooling2d:
700 {
701 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
702 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
703 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100704 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
705 OverrideDataType(output, dataType),
706 cLayer->GetParameters(),
707 reason);
telsoa014fcda012018-03-09 14:13:49 +0000708 break;
709 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000710 case LayerType::PreCompiled:
711 {
712 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
713 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
714 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
715 cLayer->GetParameters(),
716 reason);
717 break;
718 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000719 case LayerType::Quantize:
720 {
721 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
722 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
723 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
724 break;
725 }
James Conroyee18dc82019-07-17 11:27:46 +0100726 case LayerType::QuantizedLstm:
727 {
728 auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
729
730 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100731 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
732 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
733 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100734
735 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100736 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
737 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100738
739 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100740 QuantizedLstmInputParamsInfo paramsInfo;
741
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100742 paramsInfo.m_InputToInputWeights =
743 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
744 paramsInfo.m_InputToForgetWeights =
745 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
746 paramsInfo.m_InputToCellWeights =
747 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
748 paramsInfo.m_InputToOutputWeights =
749 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100750
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100751 paramsInfo.m_RecurrentToInputWeights =
752 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
753 paramsInfo.m_RecurrentToForgetWeights =
754 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
755 paramsInfo.m_RecurrentToCellWeights =
756 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
757 paramsInfo.m_RecurrentToOutputWeights =
758 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100759
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100760 paramsInfo.m_InputGateBias =
761 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
762 paramsInfo.m_ForgetGateBias =
763 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
764 paramsInfo.m_CellBias =
765 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
766 paramsInfo.m_OutputGateBias =
767 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100768
769 result = layerSupportObject->IsQuantizedLstmSupported(input,
770 previousCellStateIn,
771 previousOutputIn,
772 cellStateOut,
773 output,
774 paramsInfo,
775 reason);
776 break;
777 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100778 case LayerType::Division:
779 {
780 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
781 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
782 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100783 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100784 OverrideDataType(input0, dataType),
785 OverrideDataType(input1, dataType),
786 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100787 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100788 break;
789 }
telsoa014fcda012018-03-09 14:13:49 +0000790 case LayerType::Reshape:
791 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000792 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000793 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000794 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
795 cLayer->GetParameters(),
796 reason);
telsoa014fcda012018-03-09 14:13:49 +0000797 break;
798 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100799 case LayerType::Resize:
800 {
801 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100802 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100803 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
804 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
805 OverrideDataType(output, dataType),
806 cLayer->GetParameters(),
807 reason);
808 break;
809 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000810 case LayerType::Rsqrt:
811 {
812 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
813 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
814 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
815 OverrideDataType(output, dataType),
816 reason);
817 break;
818 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100819 case LayerType::Slice:
820 {
821 auto cLayer = boost::polymorphic_downcast<const SliceLayer*>(&layer);
822
823 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
824 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
825
826 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
827 OverrideDataType(output, dataType),
828 cLayer->GetParameters(),
829 reason);
830 break;
831 }
telsoa014fcda012018-03-09 14:13:49 +0000832 case LayerType::Softmax:
833 {
834 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
835 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100836 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100837 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
838 OverrideDataType(output, dataType),
839 cLayer->GetParameters(),
840 reason);
telsoa014fcda012018-03-09 14:13:49 +0000841 break;
842 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000843 case LayerType::SpaceToBatchNd:
844 {
845 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
846 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
847 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
848 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
849 OverrideDataType(output, dataType),
850 cLayer->GetParameters(),
851 reason);
852 break;
853 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100854 case LayerType::SpaceToDepth:
855 {
856 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
857
858 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
859 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
860
861 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
862 OverrideDataType(output, dataType),
863 cLayer->GetParameters(),
864 reason);
865 break;
866 }
telsoa014fcda012018-03-09 14:13:49 +0000867 case LayerType::Splitter:
868 {
869 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
870 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100871
872 // Get vector of all outputs.
873 auto getTensorInfo = [&dataType](const OutputSlot& slot)
874 {
875 return OverrideDataType(slot.GetTensorInfo(), dataType);
876 };
877 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
878 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
879 std::vector<TensorInfo> outputs(beginI, endI);
880
881 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
882
David Beck33f0ae02018-10-18 15:13:56 +0100883 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100884 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100885 cLayer->GetParameters(),
886 reason);
telsoa014fcda012018-03-09 14:13:49 +0000887 break;
888 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100889 case LayerType::Stack:
890 {
891 auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);
892
893 // Get vector of all inputs.
894 auto getTensorInfo = [&dataType](const InputSlot& slot)
895 {
896 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
897 };
898 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
899 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
900 std::vector<TensorInfo> inputs(beginI, endI);
901
902 auto getTensorInfoPtr = [](const TensorInfo& info)
903 {
904 return &info;
905 };
906 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
907 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
908 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
909
910 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
911
912 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
913
914 break;
915 }
Derek Lamberti013c3902019-10-21 10:46:16 +0100916 case LayerType::StandIn:
917 {
918 auto cLayer = boost::polymorphic_downcast<const StandInLayer*>(&layer);
919
920 // Get vector of all inputs.
921 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
922 {
923 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
924 };
925 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
926 {
927 return OverrideDataType(slot.GetTensorInfo(), dataType);
928 };
929 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
930 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
931 std::vector<TensorInfo> inputs(beginI, endI);
932
933 auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
934 auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
935 std::vector<TensorInfo> outputs(beginO, endO);
936
937
938 auto getTensorInfoPtr = [](const TensorInfo& info)
939 {
940 return &info;
941 };
942 auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
943 auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
944 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
945
946 auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
947 auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
948 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
949
950
951 result = layerSupportObject->IsStandInSupported(inputPtrs,
952 outputPtrs,
953 cLayer->GetParameters(),
954 reason);
955 break;
956 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000957 case LayerType::StridedSlice:
958 {
959 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
960 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
961 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
962 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
963 OverrideDataType(output, dataType),
964 cLayer->GetParameters(),
965 reason);
966 break;
967 }
David Beckc2044fe2018-09-05 15:00:38 +0100968 case LayerType::Subtraction:
969 {
970 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
971 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
972 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100973 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100974 OverrideDataType(input0, dataType),
975 OverrideDataType(input1, dataType),
976 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100977 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100978 break;
979 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100980 case LayerType::Switch:
981 {
982 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
983 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
984 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
985 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
986 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
987 OverrideDataType(input1, dataType),
988 OverrideDataType(output0, dataType),
989 OverrideDataType(output1, dataType),
990 reason);
991 break;
992 }
narpra0132b90462018-09-13 11:07:48 +0100993 case LayerType::Mean:
994 {
995 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
996 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
997 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100998 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100999 OverrideDataType(input, dataType),
1000 OverrideDataType(output, dataType),
1001 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001002 reason);
narpra0132b90462018-09-13 11:07:48 +01001003 break;
1004 }
kevmay0190539692018-11-29 08:40:19 +00001005 case LayerType::Minimum:
1006 {
1007 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1008 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1009 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1010 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1011 OverrideDataType(input1, dataType),
1012 OverrideDataType(output, dataType),
1013 reason);
1014 break;
1015 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001016 case LayerType::Prelu:
1017 {
1018 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1019 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1020 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1021 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1022 OverrideDataType(alpha, dataType),
1023 OverrideDataType(output, dataType),
1024 reason);
1025 break;
1026 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001027 case LayerType::TransposeConvolution2d:
1028 {
1029 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
1030
1031 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1032 dataType);
1033 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1034
1035 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1036
1037 Optional<TensorInfo> biases;
1038 if (descriptor.m_BiasEnabled)
1039 {
1040 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
1041 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1042 GetBiasTypeFromWeightsType(dataType));
1043 }
1044
1045 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
1046 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1047
1048 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1049 output,
1050 descriptor,
1051 weights,
1052 biases,
1053 reason);
1054
1055 break;
1056 }
telsoa014fcda012018-03-09 14:13:49 +00001057 default:
1058 {
1059 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001060 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001061 result = false;
1062 break;
1063 }
1064 }
telsoa014fcda012018-03-09 14:13:49 +00001065 return result;
1066}
1067
David Beckdcb751f2018-10-03 11:42:42 +01001068bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001069 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001070 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001071{
David Beckdcb751f2018-10-03 11:42:42 +01001072 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +01001073 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +00001074}
1075
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001076// Default Implementations
Kevin May868eb142019-09-04 17:29:31 +01001077std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
1078 const WorkloadInfo& info) const
1079{
1080 return std::unique_ptr<IWorkload>();
1081}
1082
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001083std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
1084 const WorkloadInfo& info) const
1085{
1086 return std::unique_ptr<IWorkload>();
1087}
1088
1089std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
1090 const WorkloadInfo& info) const
1091{
1092 return std::unique_ptr<IWorkload>();
1093}
1094
Nikhil Rajee391d52019-09-05 17:50:44 +01001095std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
1096 const WorkloadInfo& info) const
1097{
1098 return std::unique_ptr<IWorkload>();
1099}
1100
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001101std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
1102 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
1103{
1104 return std::unique_ptr<IWorkload>();
1105}
1106
1107std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
1108 const WorkloadInfo& Info) const
1109{
1110 return std::unique_ptr<IWorkload>();
1111}
1112
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001113std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
1114 const WorkloadInfo& info) const
1115{
1116 return std::unique_ptr<IWorkload>();
1117}
1118
Jim Flynne242f2d2019-05-22 14:24:13 +01001119std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +01001120 const WorkloadInfo& info) const
1121{
1122 return std::unique_ptr<IWorkload>();
1123}
1124
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001125std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
1126 const WorkloadInfo& info) const
1127{
1128 return std::unique_ptr<IWorkload>();
1129}
1130
1131std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
1132 const WorkloadInfo& info) const
1133{
1134 return std::unique_ptr<IWorkload>();
1135}
1136
1137std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
1138 const WorkloadInfo& info) const
1139{
1140 return std::unique_ptr<IWorkload>();
1141}
1142
1143std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
1144 const WorkloadInfo& info) const
1145{
1146 return std::unique_ptr<IWorkload>();
1147}
1148
1149std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
1150 const WorkloadInfo& info) const
1151{
1152 return std::unique_ptr<IWorkload>();
1153}
1154
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001155std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
1156 const WorkloadInfo& info) const
1157{
1158 return std::unique_ptr<IWorkload>();
1159}
1160
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001161std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
1162 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
1163{
1164 return std::unique_ptr<IWorkload>();
1165}
1166
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001167std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
1168 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
1169{
1170 return std::unique_ptr<IWorkload>();
1171}
1172
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001173std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
1174 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
1175{
1176 return std::unique_ptr<IWorkload>();
1177}
1178
1179std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
1180 const WorkloadInfo& info) const
1181{
1182 return std::unique_ptr<IWorkload>();
1183}
1184
1185std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
1186 const WorkloadInfo& Info) const
1187{
1188 return std::unique_ptr<IWorkload>();
1189}
1190
1191std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
1192 const WorkloadInfo& info) const
1193{
1194 return std::unique_ptr<IWorkload>();
1195}
1196
1197std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
1198 const WorkloadInfo& info) const
1199{
1200 return std::unique_ptr<IWorkload>();
1201}
1202
1203std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
1204 const WorkloadInfo& info) const
1205{
1206 return std::unique_ptr<IWorkload>();
1207}
1208
1209std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
1210 const WorkloadInfo& info) const
1211{
1212 return std::unique_ptr<IWorkload>();
1213}
1214
1215std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
1216 const WorkloadInfo& info) const
1217{
1218 return std::unique_ptr<IWorkload>();
1219}
1220
Kevin Mayce5045a2019-10-02 14:07:47 +01001221std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
1222 const InstanceNormalizationQueueDescriptor& descriptor,
1223 const WorkloadInfo& info) const
1224{
1225 return std::unique_ptr<IWorkload>();
1226}
1227
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001228std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
1229 const WorkloadInfo& info) const
1230{
1231 return std::unique_ptr<IWorkload>();
1232}
1233
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001234std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
1235 const WorkloadInfo& info) const
1236{
1237 return std::unique_ptr<IWorkload>();
1238}
1239
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001240std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
1241 const WorkloadInfo& info) const
1242{
1243 return std::unique_ptr<IWorkload>();
1244}
1245
1246std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
1247 const WorkloadInfo& info) const
1248{
1249 return std::unique_ptr<IWorkload>();
1250}
1251
1252std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1253 const WorkloadInfo& Info) const
1254{
1255 return std::unique_ptr<IWorkload>();
1256}
1257
1258std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1259 const WorkloadInfo& info) const
1260{
1261 return std::unique_ptr<IWorkload>();
1262}
1263
Derek Lambertif674aa02019-08-01 15:56:25 +01001264std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
1265 const WorkloadInfo& info) const
1266{
1267 return std::unique_ptr<IWorkload>();
1268}
1269
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001270std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1271 const WorkloadInfo& info) const
1272{
1273 return std::unique_ptr<IWorkload>();
1274}
1275
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001276std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1277 const WorkloadInfo& info) const
1278{
1279 return std::unique_ptr<IWorkload>();
1280}
1281
1282std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1283 const WorkloadInfo& info) const
1284{
1285 return std::unique_ptr<IWorkload>();
1286}
1287
1288std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1289 const WorkloadInfo& info) const
1290{
1291 return std::unique_ptr<IWorkload>();
1292}
1293
1294std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1295 const WorkloadInfo& info) const
1296{
1297 return std::unique_ptr<IWorkload>();
1298}
1299
1300std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1301 const WorkloadInfo& info) const
1302{
1303 return std::unique_ptr<IWorkload>();
1304}
1305
1306std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1307 const WorkloadInfo& Info) const
1308{
1309 return std::unique_ptr<IWorkload>();
1310}
1311
1312std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1313 const WorkloadInfo& info) const
1314{
1315 return std::unique_ptr<IWorkload>();
1316}
1317
1318std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1319 const WorkloadInfo& info) const
1320{
1321 return std::unique_ptr<IWorkload>();
1322}
1323
1324std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1325 const WorkloadInfo& info) const
1326{
1327 return std::unique_ptr<IWorkload>();
1328}
1329
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001330std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1331 const WorkloadInfo &info) const
1332{
1333 return std::unique_ptr<IWorkload>();
1334}
1335
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001336std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1337 const WorkloadInfo& Info) const
1338{
1339 return std::unique_ptr<IWorkload>();
1340}
1341
James Conroyee18dc82019-07-17 11:27:46 +01001342std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
1343 const WorkloadInfo& info) const
1344{
1345 return std::unique_ptr<IWorkload>();
1346}
1347
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001348std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1349 const WorkloadInfo& info) const
1350{
1351 return std::unique_ptr<IWorkload>();
1352}
1353
1354std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1355 const WorkloadInfo& info) const
1356{
1357 return std::unique_ptr<IWorkload>();
1358}
1359
Teresa Charlina9075df2019-06-27 15:41:57 +01001360std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1361 const WorkloadInfo& info) const
1362{
1363 return std::unique_ptr<IWorkload>();
1364}
1365
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001366std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1367 const WorkloadInfo& info) const
1368{
1369 return std::unique_ptr<IWorkload>();
1370}
1371
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001372std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
1373 const WorkloadInfo& info) const
1374{
1375 return std::unique_ptr<IWorkload>();
1376}
1377
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001378std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1379 const WorkloadInfo& info) const
1380{
1381 return std::unique_ptr<IWorkload>();
1382}
1383
1384std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1385 const WorkloadInfo& info) const
1386{
1387 return std::unique_ptr<IWorkload>();
1388}
1389
1390std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1391 const WorkloadInfo& info) const
1392{
1393 return std::unique_ptr<IWorkload>();
1394}
1395
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001396std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1397 const WorkloadInfo& info) const
1398{
1399 return std::unique_ptr<IWorkload>();
1400}
1401
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001402std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
1403 const WorkloadInfo& info) const
1404{
1405 return std::unique_ptr<IWorkload>();
1406}
1407
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001408std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1409 const WorkloadInfo& Info) const
1410{
1411 return std::unique_ptr<IWorkload>();
1412}
1413
1414std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1415 const WorkloadInfo& info) const
1416{
1417 return std::unique_ptr<IWorkload>();
1418}
1419
Sadik Armaganeff363d2019-04-05 15:25:46 +01001420std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1421 const WorkloadInfo& info) const
1422{
1423 return std::unique_ptr<IWorkload>();
1424}
1425
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001426std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1427 const TransposeConvolution2dQueueDescriptor& descriptor,
1428 const WorkloadInfo& info) const
1429{
1430 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001431}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001432
} // namespace armnn