blob: 9d081af8e984c99701464fffce2b46f04ef9ecbc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
Derek Lambertia9cca6a2019-03-25 15:41:58 +00007#include "WorkloadFactory.hpp"
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Layer.hpp>
11#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010012
David Beckb4540be2018-09-24 13:18:27 +010013#include <armnn/Types.hpp>
14#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000015#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000019#include <backendsCommon/IBackendInternal.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/iterator/transform_iterator.hpp>
24
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000025#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000026#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000027
telsoa014fcda012018-03-09 14:13:49 +000028namespace armnn
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031namespace
32{
telsoa01c577f2c2018-08-31 09:22:23 +010033
David Beck29c75de2018-10-23 13:35:58 +010034const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
35{
36 if (!type)
37 {
38 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010039 }
40
David Beck29c75de2018-10-23 13:35:58 +010041 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
David Beck33f0ae02018-10-18 15:13:56 +010046bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010047 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010048 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010049 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000050{
David Beck33f0ae02018-10-18 15:13:56 +010051 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000052 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010053 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
54
David Beck111b5d92018-11-12 14:59:37 +000055 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
57 {
58 std::stringstream ss;
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
61
62 outReasonIfUnsupported = ss.str();
63 return false;
64 }
65
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010069
telsoa014fcda012018-03-09 14:13:49 +000070 switch(layer.GetType())
71 {
Kevin May868eb142019-09-04 17:29:31 +010072 case LayerType::Abs:
73 {
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76 result = layerSupportObject->IsAbsSupported(OverrideDataType(input, dataType),
77 OverrideDataType(output, dataType),
78 reason);
79 break;
80 }
telsoa014fcda012018-03-09 14:13:49 +000081 case LayerType::Activation:
82 {
83 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
84 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010085 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010086 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010087 OverrideDataType(input, dataType),
88 OverrideDataType(output, dataType),
89 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010090 reason);
telsoa014fcda012018-03-09 14:13:49 +000091 break;
92 }
93 case LayerType::Addition:
94 {
95 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
96 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
97 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010098 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010099 OverrideDataType(input0, dataType),
100 OverrideDataType(input1, dataType),
101 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100102 reason);
telsoa014fcda012018-03-09 14:13:49 +0000103 break;
104 }
105 case LayerType::BatchNormalization:
106 {
107 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
108 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100109 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
110 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
111 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
112 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
113 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100114 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100115 OverrideDataType(input, dataType),
116 OverrideDataType(output, dataType),
117 OverrideDataType(mean, dataType),
118 OverrideDataType(var, dataType),
119 OverrideDataType(beta, dataType),
120 OverrideDataType(gamma, dataType),
121 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100122 reason);
telsoa014fcda012018-03-09 14:13:49 +0000123 break;
124 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000125 case LayerType::BatchToSpaceNd:
126 {
127 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
128 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
129 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
130
131 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
132 OverrideDataType(output, dataType),
133 cLayer->GetParameters(),
134 reason);
135 break;
136 }
telsoa014fcda012018-03-09 14:13:49 +0000137 case LayerType::Constant:
138 {
139 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100140 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100141 break;
142 }
143 case LayerType::ConvertFp16ToFp32:
144 {
145 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100147 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100148 break;
149 }
150 case LayerType::ConvertFp32ToFp16:
151 {
152 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
153 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100154 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000155 break;
156 }
157 case LayerType::Convolution2d:
158 {
159 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100160
161 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
162 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100163 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100164 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
165
arovir01a6824102018-08-28 17:40:45 +0100166 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100167
arovir01a6824102018-08-28 17:40:45 +0100168 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100169 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100170 if (descriptor.m_BiasEnabled)
171 {
David Beck5eec11d2018-10-04 15:43:17 +0100172 biases =
173 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100174 }
175
David Beck33f0ae02018-10-18 15:13:56 +0100176 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100177 input,
178 output,
179 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100180 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100181 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100182 reason);
telsoa014fcda012018-03-09 14:13:49 +0000183 break;
184 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000185 case LayerType::Debug:
186 {
187 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
188 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
189
190 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
191 OverrideDataType(output, dataType),
192 reason);
193 break;
194 }
telsoa014fcda012018-03-09 14:13:49 +0000195 case LayerType::DepthwiseConvolution2d:
196 {
197 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100198 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
199 dataType);
200 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
201 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
202
telsoa01c577f2c2018-08-31 09:22:23 +0100203 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100204
205 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100206 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100207 if (descriptor.m_BiasEnabled)
208 {
David Beck5eec11d2018-10-04 15:43:17 +0100209 biases =
210 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100211 }
telsoa01c577f2c2018-08-31 09:22:23 +0100212
David Beck33f0ae02018-10-18 15:13:56 +0100213 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100214 input,
215 output,
216 descriptor,
217 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100218 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100219 reason);
telsoa014fcda012018-03-09 14:13:49 +0000220 break;
221 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000222 case LayerType::Dequantize:
223 {
224 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
225 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
226
227 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
228 OverrideDataType(output, DataType::Float32),
229 reason);
230 break;
231 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000232 case LayerType::DetectionPostProcess:
233 {
234 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
235 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
236 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
237 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
238 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
239 input1,
240 descriptor,
241 reason);
242 break;
243 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000244 case LayerType::Equal:
245 {
246 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
247 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
248 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
249 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
250 OverrideDataType(input1, dataType),
251 OverrideDataType(output, dataType),
252 reason);
253 break;
254 }
telsoa014fcda012018-03-09 14:13:49 +0000255 case LayerType::FakeQuantization:
256 {
257 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
258 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
260 cLayer->GetParameters(),
261 reason);
telsoa014fcda012018-03-09 14:13:49 +0000262 break;
263 }
264 case LayerType::Floor:
265 {
266 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
267 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100268 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
269 OverrideDataType(output, dataType),
270 reason);
telsoa014fcda012018-03-09 14:13:49 +0000271 break;
272 }
273 case LayerType::FullyConnected:
274 {
275 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
276 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100277 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
278 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
279
280 TensorInfo biasInfo;
281 const TensorInfo * biasInfoPtr = nullptr;
282 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
283 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
284 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
285
286 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
287 if (descriptor.m_BiasEnabled)
288 {
289 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
290 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
291 biasInfoPtr = &biasInfo;
292 }
293 else
294 {
295 // If biases are not enabled pass a dummy tensorinfo for the validation
296 switch(input.GetDataType())
297 {
298 case DataType::Float16:
299 {
300 biasInfoPtr = &dummyFloat16Bias;
301 break;
302 }
303 case DataType::Float32:
304 {
305 biasInfoPtr = &dummyFloat32Bias;
306 break;
307 }
308 case DataType::QuantisedAsymm8:
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100309 case DataType::QuantisedSymm16:
telsoa01c577f2c2018-08-31 09:22:23 +0100310 {
311 biasInfoPtr = &dummyQA8Bias;
312 break;
313 }
314 default:
315 {
316 BOOST_ASSERT_MSG(false, "Unexpected bias type");
317 }
318 }
319 }
320
David Beck33f0ae02018-10-18 15:13:56 +0100321 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100322 OverrideDataType(input, dataType),
323 OverrideDataType(output, dataType),
324 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
325 *biasInfoPtr,
326 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100327 reason);
telsoa014fcda012018-03-09 14:13:49 +0000328 break;
329 }
narpra01b89b05f2019-01-16 09:53:09 +0000330 case LayerType::Gather:
331 {
332 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
333 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
334 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
335 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100336 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000337 OverrideDataType(output, dataType),
338 reason);
339 break;
340 }
telsoa014fcda012018-03-09 14:13:49 +0000341 case LayerType::Input:
342 {
343 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100344 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000345 break;
346 }
347 case LayerType::L2Normalization:
348 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100349 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
350 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
351
telsoa014fcda012018-03-09 14:13:49 +0000352 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100353 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100354
David Beck33f0ae02018-10-18 15:13:56 +0100355 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100356 OverrideDataType(input, dataType),
357 OverrideDataType(output, dataType),
358 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100359 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100360 break;
361 }
362 case LayerType::Lstm:
363 {
364 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
365 const LstmDescriptor& descriptor = cLayer->GetParameters();
366
367 // All inputs.
368 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
369 dataType);
370 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
371 dataType);
372 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
373 dataType);
374 // All outputs
375 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
376 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
377 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
378 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
379
380 // Basic parameters
381 const TensorInfo& inputToForgetWeights
382 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
383 const TensorInfo& inputToCellWeights
384 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
385 const TensorInfo& inputToOutputWeights
386 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
387 const TensorInfo& recurrentToForgetWeights
388 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
389 const TensorInfo& recurrentToCellWeights
390 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
391 const TensorInfo& recurrentToOutputWeights
392 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
393 const TensorInfo& forgetGateBias
394 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
395 const TensorInfo& cellBias
396 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
397 const TensorInfo& outputGateBias
398 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
399
Jan Eilersd01a83c2019-07-03 18:20:40 +0100400 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100401
Jan Eilersd01a83c2019-07-03 18:20:40 +0100402 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
403 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
404 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
405 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
406 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
407 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
408 paramsInfo.m_ForgetGateBias = &forgetGateBias;
409 paramsInfo.m_CellBias = &cellBias;
410 paramsInfo.m_OutputGateBias = &outputGateBias;
411
412
413 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100414 TensorInfo optInputToInputWeights;
415 TensorInfo optRecurrentToInputWeights;
416 TensorInfo optCellToInputWeights;
417 TensorInfo optInputGateBias;
418 TensorInfo optProjectionWeights;
419 TensorInfo optProjectionBias;
420 TensorInfo optCellToForgetWeights;
421 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100422 TensorInfo optInputLayerNormWeights;
423 TensorInfo optForgetLayerNormWeights;
424 TensorInfo optCellLayerNormWeights;
425 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100426
427 if(!descriptor.m_CifgEnabled)
428 {
429 optInputToInputWeights =
430 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100431 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100432
433 optRecurrentToInputWeights =
434 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100435 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100436 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
437 {
438 optCellToInputWeights =
439 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100440 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100441 }
442 optInputGateBias =
443 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100444 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100445 }
446
447 if(descriptor.m_ProjectionEnabled)
448 {
449 optProjectionWeights =
450 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100451 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100452 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
453 {
454 optProjectionBias =
455 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100456 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100457 }
458 }
459
460 if(descriptor.m_PeepholeEnabled)
461 {
462 optCellToForgetWeights =
463 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100464 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100465 optCellToOutputWeights =
466 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100467 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100468 }
469
Jan Eilers38e05bd2019-06-26 13:10:09 +0100470 if(descriptor.m_LayerNormEnabled)
471 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100472 if (!descriptor.m_CifgEnabled)
473 {
474 optInputLayerNormWeights = OverrideDataType(
475 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
476 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
477 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100478
479 optForgetLayerNormWeights = OverrideDataType(
480 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100481 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100482
483 optCellLayerNormWeights = OverrideDataType(
484 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100485 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100486
487 optOutputLayerNormWeights = OverrideDataType(
488 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100489 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100490 }
491
David Beck33f0ae02018-10-18 15:13:56 +0100492 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100493 input,
494 outputStateIn,
495 cellStateIn,
496 scratchBuffer,
497 outputStateOut,
498 cellStateOut,
499 output,
500 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100501 paramsInfo,
502 reason);
telsoa014fcda012018-03-09 14:13:49 +0000503 break;
504 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000505 case LayerType::Maximum:
506 {
507 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
508 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
509 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
510
511 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
512 OverrideDataType(input1, dataType),
513 OverrideDataType(output, dataType),
514 reason);
515 break;
516 }
narpra01b89b05f2019-01-16 09:53:09 +0000517 case LayerType::MemCopy:
518 {
519 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
520 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000521
narpra01b89b05f2019-01-16 09:53:09 +0000522 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
523 OverrideDataType(output, dataType),
524 reason);
525 break;
526 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100527 case LayerType::MemImport:
528 {
529 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
530 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
531
532 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
533 OverrideDataType(output, dataType),
534 reason);
535 break;
536 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100537 case LayerType::Merge:
538 {
539 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
540 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
541 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
542
543 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
544 OverrideDataType(input1, dataType),
545 OverrideDataType(output, dataType),
546 reason);
547 break;
548 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100549 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000550 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100551 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000552
telsoa01c577f2c2018-08-31 09:22:23 +0100553 // Get vector of all inputs.
554 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000555 {
telsoa01c577f2c2018-08-31 09:22:23 +0100556 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000557 };
telsoa01c577f2c2018-08-31 09:22:23 +0100558 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
559 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
560 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000561
telsoa01c577f2c2018-08-31 09:22:23 +0100562 auto getTensorInfoPtr = [](const TensorInfo& info)
563 {
564 return &info;
565 };
566 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
567 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
568 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000569
Nikhil Raj8599a412018-11-19 14:51:07 +0000570 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
571
Jim Flynne242f2d2019-05-22 14:24:13 +0100572 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
573
574
telsoa014fcda012018-03-09 14:13:49 +0000575 break;
576 }
577 case LayerType::Multiplication:
578 {
579 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
580 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100581 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100582 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100583 OverrideDataType(input0, dataType),
584 OverrideDataType(input1, dataType),
585 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100586 reason);
telsoa014fcda012018-03-09 14:13:49 +0000587 break;
588 }
589 case LayerType::Normalization:
590 {
591 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
592 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
593 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100594 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
595 OverrideDataType(output, dataType),
596 cLayer->GetParameters(),
597 reason);
telsoa014fcda012018-03-09 14:13:49 +0000598 break;
599 }
600 case LayerType::Output:
601 {
602 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100603 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000604 break;
605 }
606 case LayerType::Permute:
607 {
608 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
609 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
610 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100611 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
612 OverrideDataType(output, dataType),
613 cLayer->GetParameters(),
614 reason);
telsoa014fcda012018-03-09 14:13:49 +0000615 break;
616 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100617 case LayerType::Pad:
618 {
619 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
620 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
621 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100622 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100623 OverrideDataType(input, dataType),
624 OverrideDataType(output, dataType),
625 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100626 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100627 break;
628 }
telsoa014fcda012018-03-09 14:13:49 +0000629 case LayerType::Pooling2d:
630 {
631 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
632 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
633 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100634 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
635 OverrideDataType(output, dataType),
636 cLayer->GetParameters(),
637 reason);
telsoa014fcda012018-03-09 14:13:49 +0000638 break;
639 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000640 case LayerType::PreCompiled:
641 {
642 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
643 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
644 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
645 cLayer->GetParameters(),
646 reason);
647 break;
648 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000649 case LayerType::Quantize:
650 {
651 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
652 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
653 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
654 break;
655 }
James Conroyee18dc82019-07-17 11:27:46 +0100656 case LayerType::QuantizedLstm:
657 {
658 auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
659
660 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100661 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
662 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
663 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100664
665 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100666 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
667 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100668
669 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100670 QuantizedLstmInputParamsInfo paramsInfo;
671
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100672 paramsInfo.m_InputToInputWeights =
673 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
674 paramsInfo.m_InputToForgetWeights =
675 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
676 paramsInfo.m_InputToCellWeights =
677 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
678 paramsInfo.m_InputToOutputWeights =
679 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100680
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100681 paramsInfo.m_RecurrentToInputWeights =
682 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
683 paramsInfo.m_RecurrentToForgetWeights =
684 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
685 paramsInfo.m_RecurrentToCellWeights =
686 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
687 paramsInfo.m_RecurrentToOutputWeights =
688 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100689
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100690 paramsInfo.m_InputGateBias =
691 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
692 paramsInfo.m_ForgetGateBias =
693 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
694 paramsInfo.m_CellBias =
695 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
696 paramsInfo.m_OutputGateBias =
697 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100698
699 result = layerSupportObject->IsQuantizedLstmSupported(input,
700 previousCellStateIn,
701 previousOutputIn,
702 cellStateOut,
703 output,
704 paramsInfo,
705 reason);
706 break;
707 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100708 case LayerType::Division:
709 {
710 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
711 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
712 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100713 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100714 OverrideDataType(input0, dataType),
715 OverrideDataType(input1, dataType),
716 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100717 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100718 break;
719 }
telsoa014fcda012018-03-09 14:13:49 +0000720 case LayerType::Reshape:
721 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000722 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000723 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000724 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
725 cLayer->GetParameters(),
726 reason);
telsoa014fcda012018-03-09 14:13:49 +0000727 break;
728 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100729 case LayerType::Resize:
730 {
731 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100732 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100733 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
734 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
735 OverrideDataType(output, dataType),
736 cLayer->GetParameters(),
737 reason);
738 break;
739 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000740 case LayerType::Rsqrt:
741 {
742 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
743 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
744 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
745 OverrideDataType(output, dataType),
746 reason);
747 break;
748 }
telsoa014fcda012018-03-09 14:13:49 +0000749 case LayerType::Softmax:
750 {
751 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
752 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100753 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100754 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
755 OverrideDataType(output, dataType),
756 cLayer->GetParameters(),
757 reason);
telsoa014fcda012018-03-09 14:13:49 +0000758 break;
759 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000760 case LayerType::SpaceToBatchNd:
761 {
762 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
763 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
764 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
765 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
766 OverrideDataType(output, dataType),
767 cLayer->GetParameters(),
768 reason);
769 break;
770 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100771 case LayerType::SpaceToDepth:
772 {
773 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
774
775 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
776 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
777
778 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
779 OverrideDataType(output, dataType),
780 cLayer->GetParameters(),
781 reason);
782 break;
783 }
telsoa014fcda012018-03-09 14:13:49 +0000784 case LayerType::Splitter:
785 {
786 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
787 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100788
789 // Get vector of all outputs.
790 auto getTensorInfo = [&dataType](const OutputSlot& slot)
791 {
792 return OverrideDataType(slot.GetTensorInfo(), dataType);
793 };
794 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
795 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
796 std::vector<TensorInfo> outputs(beginI, endI);
797
798 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
799
David Beck33f0ae02018-10-18 15:13:56 +0100800 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100801 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100802 cLayer->GetParameters(),
803 reason);
telsoa014fcda012018-03-09 14:13:49 +0000804 break;
805 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100806 case LayerType::Stack:
807 {
808 auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);
809
810 // Get vector of all inputs.
811 auto getTensorInfo = [&dataType](const InputSlot& slot)
812 {
813 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
814 };
815 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
816 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
817 std::vector<TensorInfo> inputs(beginI, endI);
818
819 auto getTensorInfoPtr = [](const TensorInfo& info)
820 {
821 return &info;
822 };
823 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
824 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
825 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
826
827 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
828
829 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
830
831 break;
832 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000833 case LayerType::StridedSlice:
834 {
835 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
836 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
837 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
838 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
839 OverrideDataType(output, dataType),
840 cLayer->GetParameters(),
841 reason);
842 break;
843 }
David Beckc2044fe2018-09-05 15:00:38 +0100844 case LayerType::Subtraction:
845 {
846 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
847 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
848 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100849 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100850 OverrideDataType(input0, dataType),
851 OverrideDataType(input1, dataType),
852 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100853 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100854 break;
855 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100856 case LayerType::Switch:
857 {
858 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
859 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
860 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
861 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
862 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
863 OverrideDataType(input1, dataType),
864 OverrideDataType(output0, dataType),
865 OverrideDataType(output1, dataType),
866 reason);
867 break;
868 }
narpra0132b90462018-09-13 11:07:48 +0100869 case LayerType::Mean:
870 {
871 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
872 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
873 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100874 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100875 OverrideDataType(input, dataType),
876 OverrideDataType(output, dataType),
877 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100878 reason);
narpra0132b90462018-09-13 11:07:48 +0100879 break;
880 }
kevmay0190539692018-11-29 08:40:19 +0000881 case LayerType::Minimum:
882 {
883 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
884 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
885 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
886 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
887 OverrideDataType(input1, dataType),
888 OverrideDataType(output, dataType),
889 reason);
890 break;
891 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000892 case LayerType::Greater:
893 {
894 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
895 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
896 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000897 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
898 OverrideDataType(input1, dataType),
899 OverrideDataType(output, DataType::Boolean),
900 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000901 break;
902 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +0100903 case LayerType::Prelu:
904 {
905 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
906 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
907 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
908 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
909 OverrideDataType(alpha, dataType),
910 OverrideDataType(output, dataType),
911 reason);
912 break;
913 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +0100914 case LayerType::TransposeConvolution2d:
915 {
916 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
917
918 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
919 dataType);
920 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
921
922 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
923
924 Optional<TensorInfo> biases;
925 if (descriptor.m_BiasEnabled)
926 {
927 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
928 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
929 GetBiasTypeFromWeightsType(dataType));
930 }
931
932 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
933 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
934
935 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
936 output,
937 descriptor,
938 weights,
939 biases,
940 reason);
941
942 break;
943 }
telsoa014fcda012018-03-09 14:13:49 +0000944 default:
945 {
946 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100947 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000948 result = false;
949 break;
950 }
951 }
telsoa014fcda012018-03-09 14:13:49 +0000952 return result;
953}
954
David Beckdcb751f2018-10-03 11:42:42 +0100955bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100956 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100957 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000958{
David Beckdcb751f2018-10-03 11:42:42 +0100959 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100960 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000961}
962
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000963// Default Implementations
Kevin May868eb142019-09-04 17:29:31 +0100964std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
965 const WorkloadInfo& info) const
966{
967 return std::unique_ptr<IWorkload>();
968}
969
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000970std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
971 const WorkloadInfo& info) const
972{
973 return std::unique_ptr<IWorkload>();
974}
975
976std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
977 const WorkloadInfo& info) const
978{
979 return std::unique_ptr<IWorkload>();
980}
981
982std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
983 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
984{
985 return std::unique_ptr<IWorkload>();
986}
987
988std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
989 const WorkloadInfo& Info) const
990{
991 return std::unique_ptr<IWorkload>();
992}
993
Jim Flynne242f2d2019-05-22 14:24:13 +0100994std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +0100995 const WorkloadInfo& info) const
996{
997 return std::unique_ptr<IWorkload>();
998}
999
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001000std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
1001 const WorkloadInfo& info) const
1002{
1003 return std::unique_ptr<IWorkload>();
1004}
1005
1006std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
1007 const WorkloadInfo& info) const
1008{
1009 return std::unique_ptr<IWorkload>();
1010}
1011
1012std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
1013 const WorkloadInfo& info) const
1014{
1015 return std::unique_ptr<IWorkload>();
1016}
1017
1018std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
1019 const WorkloadInfo& info) const
1020{
1021 return std::unique_ptr<IWorkload>();
1022}
1023
1024std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
1025 const WorkloadInfo& info) const
1026{
1027 return std::unique_ptr<IWorkload>();
1028}
1029
1030std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
1031 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
1032{
1033 return std::unique_ptr<IWorkload>();
1034}
1035
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001036std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
1037 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
1038{
1039 return std::unique_ptr<IWorkload>();
1040}
1041
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001042std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
1043 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
1044{
1045 return std::unique_ptr<IWorkload>();
1046}
1047
1048std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
1049 const WorkloadInfo& info) const
1050{
1051 return std::unique_ptr<IWorkload>();
1052}
1053
1054std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
1055 const WorkloadInfo& Info) const
1056{
1057 return std::unique_ptr<IWorkload>();
1058}
1059
1060std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
1061 const WorkloadInfo& info) const
1062{
1063 return std::unique_ptr<IWorkload>();
1064}
1065
1066std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
1067 const WorkloadInfo& info) const
1068{
1069 return std::unique_ptr<IWorkload>();
1070}
1071
1072std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
1073 const WorkloadInfo& info) const
1074{
1075 return std::unique_ptr<IWorkload>();
1076}
1077
1078std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
1079 const WorkloadInfo& info) const
1080{
1081 return std::unique_ptr<IWorkload>();
1082}
1083
1084std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
1085 const WorkloadInfo& info) const
1086{
1087 return std::unique_ptr<IWorkload>();
1088}
1089
1090std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
1091 const WorkloadInfo& info) const
1092{
1093 return std::unique_ptr<IWorkload>();
1094}
1095
1096std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
1097 const WorkloadInfo& info) const
1098{
1099 return std::unique_ptr<IWorkload>();
1100}
1101
1102std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
1103 const WorkloadInfo& info) const
1104{
1105 return std::unique_ptr<IWorkload>();
1106}
1107
1108std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1109 const WorkloadInfo& Info) const
1110{
1111 return std::unique_ptr<IWorkload>();
1112}
1113
1114std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1115 const WorkloadInfo& info) const
1116{
1117 return std::unique_ptr<IWorkload>();
1118}
1119
Derek Lambertif674aa02019-08-01 15:56:25 +01001120std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
1121 const WorkloadInfo& info) const
1122{
1123 return std::unique_ptr<IWorkload>();
1124}
1125
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001126std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1127 const WorkloadInfo& info) const
1128{
1129 return std::unique_ptr<IWorkload>();
1130}
1131
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001132std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1133 const WorkloadInfo& info) const
1134{
1135 return std::unique_ptr<IWorkload>();
1136}
1137
1138std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1139 const WorkloadInfo& info) const
1140{
1141 return std::unique_ptr<IWorkload>();
1142}
1143
1144std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1145 const WorkloadInfo& info) const
1146{
1147 return std::unique_ptr<IWorkload>();
1148}
1149
1150std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1151 const WorkloadInfo& info) const
1152{
1153 return std::unique_ptr<IWorkload>();
1154}
1155
1156std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1157 const WorkloadInfo& info) const
1158{
1159 return std::unique_ptr<IWorkload>();
1160}
1161
1162std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1163 const WorkloadInfo& Info) const
1164{
1165 return std::unique_ptr<IWorkload>();
1166}
1167
1168std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1169 const WorkloadInfo& info) const
1170{
1171 return std::unique_ptr<IWorkload>();
1172}
1173
1174std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1175 const WorkloadInfo& info) const
1176{
1177 return std::unique_ptr<IWorkload>();
1178}
1179
1180std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1181 const WorkloadInfo& info) const
1182{
1183 return std::unique_ptr<IWorkload>();
1184}
1185
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001186std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1187 const WorkloadInfo &info) const
1188{
1189 return std::unique_ptr<IWorkload>();
1190}
1191
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001192std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1193 const WorkloadInfo& Info) const
1194{
1195 return std::unique_ptr<IWorkload>();
1196}
1197
James Conroyee18dc82019-07-17 11:27:46 +01001198std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
1199 const WorkloadInfo& info) const
1200{
1201 return std::unique_ptr<IWorkload>();
1202}
1203
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001204std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1205 const WorkloadInfo& info) const
1206{
1207 return std::unique_ptr<IWorkload>();
1208}
1209
1210std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1211 const WorkloadInfo& info) const
1212{
1213 return std::unique_ptr<IWorkload>();
1214}
1215
Teresa Charlina9075df2019-06-27 15:41:57 +01001216std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1217 const WorkloadInfo& info) const
1218{
1219 return std::unique_ptr<IWorkload>();
1220}
1221
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001222std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1223 const WorkloadInfo& info) const
1224{
1225 return std::unique_ptr<IWorkload>();
1226}
1227
1228std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1229 const WorkloadInfo& info) const
1230{
1231 return std::unique_ptr<IWorkload>();
1232}
1233
1234std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1235 const WorkloadInfo& info) const
1236{
1237 return std::unique_ptr<IWorkload>();
1238}
1239
1240std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1241 const WorkloadInfo& info) const
1242{
1243 return std::unique_ptr<IWorkload>();
1244}
1245
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001246std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1247 const WorkloadInfo& info) const
1248{
1249 return std::unique_ptr<IWorkload>();
1250}
1251
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001252std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
1253 const WorkloadInfo& info) const
1254{
1255 return std::unique_ptr<IWorkload>();
1256}
1257
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001258std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1259 const WorkloadInfo& Info) const
1260{
1261 return std::unique_ptr<IWorkload>();
1262}
1263
1264std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1265 const WorkloadInfo& info) const
1266{
1267 return std::unique_ptr<IWorkload>();
1268}
1269
Sadik Armaganeff363d2019-04-05 15:25:46 +01001270std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1271 const WorkloadInfo& info) const
1272{
1273 return std::unique_ptr<IWorkload>();
1274}
1275
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001276std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1277 const TransposeConvolution2dQueueDescriptor& descriptor,
1278 const WorkloadInfo& info) const
1279{
1280 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001281}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001282
} // namespace armnn