blob: cca39198e111171ff7feadc5b625d653c83fa540 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
Derek Lambertia9cca6a2019-03-25 15:41:58 +00007#include "WorkloadFactory.hpp"
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Layer.hpp>
11#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010012
David Beckb4540be2018-09-24 13:18:27 +010013#include <armnn/Types.hpp>
14#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000015#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000019#include <backendsCommon/IBackendInternal.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/iterator/transform_iterator.hpp>
24
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000025#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000026#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000027
telsoa014fcda012018-03-09 14:13:49 +000028namespace armnn
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031namespace
32{
telsoa01c577f2c2018-08-31 09:22:23 +010033
David Beck29c75de2018-10-23 13:35:58 +010034const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
35{
36 if (!type)
37 {
38 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010039 }
40
David Beck29c75de2018-10-23 13:35:58 +010041 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
David Beck33f0ae02018-10-18 15:13:56 +010046bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010047 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010048 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010049 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000050{
David Beck33f0ae02018-10-18 15:13:56 +010051 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000052 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010053 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
54
David Beck111b5d92018-11-12 14:59:37 +000055 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
57 {
58 std::stringstream ss;
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
61
62 outReasonIfUnsupported = ss.str();
63 return false;
64 }
65
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010069
telsoa014fcda012018-03-09 14:13:49 +000070 switch(layer.GetType())
71 {
72 case LayerType::Activation:
73 {
74 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
75 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010076 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010077 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010078 OverrideDataType(input, dataType),
79 OverrideDataType(output, dataType),
80 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010081 reason);
telsoa014fcda012018-03-09 14:13:49 +000082 break;
83 }
84 case LayerType::Addition:
85 {
86 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
87 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
88 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010089 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010090 OverrideDataType(input0, dataType),
91 OverrideDataType(input1, dataType),
92 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010093 reason);
telsoa014fcda012018-03-09 14:13:49 +000094 break;
95 }
96 case LayerType::BatchNormalization:
97 {
98 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
99 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100100 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
101 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
102 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
103 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
104 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100105 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100106 OverrideDataType(input, dataType),
107 OverrideDataType(output, dataType),
108 OverrideDataType(mean, dataType),
109 OverrideDataType(var, dataType),
110 OverrideDataType(beta, dataType),
111 OverrideDataType(gamma, dataType),
112 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100113 reason);
telsoa014fcda012018-03-09 14:13:49 +0000114 break;
115 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000116 case LayerType::BatchToSpaceNd:
117 {
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
121
122 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
123 OverrideDataType(output, dataType),
124 cLayer->GetParameters(),
125 reason);
126 break;
127 }
telsoa014fcda012018-03-09 14:13:49 +0000128 case LayerType::Constant:
129 {
130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100131 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100132 break;
133 }
134 case LayerType::ConvertFp16ToFp32:
135 {
136 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100138 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100139 break;
140 }
141 case LayerType::ConvertFp32ToFp16:
142 {
143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100145 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000146 break;
147 }
148 case LayerType::Convolution2d:
149 {
150 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100151
152 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
153 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100154 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100155 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
156
arovir01a6824102018-08-28 17:40:45 +0100157 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100158
arovir01a6824102018-08-28 17:40:45 +0100159 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100160 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100161 if (descriptor.m_BiasEnabled)
162 {
David Beck5eec11d2018-10-04 15:43:17 +0100163 biases =
164 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100165 }
166
David Beck33f0ae02018-10-18 15:13:56 +0100167 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100168 input,
169 output,
170 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100171 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100172 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100173 reason);
telsoa014fcda012018-03-09 14:13:49 +0000174 break;
175 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000176 case LayerType::Debug:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180
181 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
182 OverrideDataType(output, dataType),
183 reason);
184 break;
185 }
telsoa014fcda012018-03-09 14:13:49 +0000186 case LayerType::DepthwiseConvolution2d:
187 {
188 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
190 dataType);
191 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
192 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
193
telsoa01c577f2c2018-08-31 09:22:23 +0100194 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100195
196 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100197 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100198 if (descriptor.m_BiasEnabled)
199 {
David Beck5eec11d2018-10-04 15:43:17 +0100200 biases =
201 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100202 }
telsoa01c577f2c2018-08-31 09:22:23 +0100203
David Beck33f0ae02018-10-18 15:13:56 +0100204 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100205 input,
206 output,
207 descriptor,
208 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100209 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100210 reason);
telsoa014fcda012018-03-09 14:13:49 +0000211 break;
212 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000213 case LayerType::Dequantize:
214 {
215 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
216 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
217
218 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
219 OverrideDataType(output, DataType::Float32),
220 reason);
221 break;
222 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000223 case LayerType::DetectionPostProcess:
224 {
225 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
226 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
227 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
228 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
229 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
230 input1,
231 descriptor,
232 reason);
233 break;
234 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000235 case LayerType::Equal:
236 {
237 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
238 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
240 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
241 OverrideDataType(input1, dataType),
242 OverrideDataType(output, dataType),
243 reason);
244 break;
245 }
telsoa014fcda012018-03-09 14:13:49 +0000246 case LayerType::FakeQuantization:
247 {
248 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
249 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100250 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
251 cLayer->GetParameters(),
252 reason);
telsoa014fcda012018-03-09 14:13:49 +0000253 break;
254 }
255 case LayerType::Floor:
256 {
257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
260 OverrideDataType(output, dataType),
261 reason);
telsoa014fcda012018-03-09 14:13:49 +0000262 break;
263 }
264 case LayerType::FullyConnected:
265 {
266 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
267 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100268 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
269 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
270
271 TensorInfo biasInfo;
272 const TensorInfo * biasInfoPtr = nullptr;
273 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
274 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
275 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
276
277 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
278 if (descriptor.m_BiasEnabled)
279 {
280 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
281 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
282 biasInfoPtr = &biasInfo;
283 }
284 else
285 {
286 // If biases are not enabled pass a dummy tensorinfo for the validation
287 switch(input.GetDataType())
288 {
289 case DataType::Float16:
290 {
291 biasInfoPtr = &dummyFloat16Bias;
292 break;
293 }
294 case DataType::Float32:
295 {
296 biasInfoPtr = &dummyFloat32Bias;
297 break;
298 }
299 case DataType::QuantisedAsymm8:
300 {
301 biasInfoPtr = &dummyQA8Bias;
302 break;
303 }
304 default:
305 {
306 BOOST_ASSERT_MSG(false, "Unexpected bias type");
307 }
308 }
309 }
310
David Beck33f0ae02018-10-18 15:13:56 +0100311 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100312 OverrideDataType(input, dataType),
313 OverrideDataType(output, dataType),
314 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
315 *biasInfoPtr,
316 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100317 reason);
telsoa014fcda012018-03-09 14:13:49 +0000318 break;
319 }
narpra01b89b05f2019-01-16 09:53:09 +0000320 case LayerType::Gather:
321 {
322 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
323 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
324 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
325 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
326 OverrideDataType(input1, dataType),
327 OverrideDataType(output, dataType),
328 reason);
329 break;
330 }
telsoa014fcda012018-03-09 14:13:49 +0000331 case LayerType::Input:
332 {
333 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100334 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000335 break;
336 }
337 case LayerType::L2Normalization:
338 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100339 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
340 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
341
telsoa014fcda012018-03-09 14:13:49 +0000342 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100343 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100344
David Beck33f0ae02018-10-18 15:13:56 +0100345 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100346 OverrideDataType(input, dataType),
347 OverrideDataType(output, dataType),
348 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100349 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100350 break;
351 }
352 case LayerType::Lstm:
353 {
354 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
355 const LstmDescriptor& descriptor = cLayer->GetParameters();
356
357 // All inputs.
358 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
359 dataType);
360 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
361 dataType);
362 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
363 dataType);
364 // All outputs
365 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
366 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
367 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
368 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
369
370 // Basic parameters
371 const TensorInfo& inputToForgetWeights
372 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
373 const TensorInfo& inputToCellWeights
374 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
375 const TensorInfo& inputToOutputWeights
376 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
377 const TensorInfo& recurrentToForgetWeights
378 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
379 const TensorInfo& recurrentToCellWeights
380 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
381 const TensorInfo& recurrentToOutputWeights
382 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
383 const TensorInfo& forgetGateBias
384 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
385 const TensorInfo& cellBias
386 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
387 const TensorInfo& outputGateBias
388 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
389
390 // Optional parameters
391 const TensorInfo* inputToInputWeights = nullptr;
392 const TensorInfo* recurrentToInputWeights = nullptr;
393 const TensorInfo* cellToInputWeights = nullptr;
394 const TensorInfo* inputGateBias = nullptr;
395 const TensorInfo* projectionWeights = nullptr;
396 const TensorInfo* projectionBias = nullptr;
397 const TensorInfo* cellToForgetWeights = nullptr;
398 const TensorInfo* cellToOutputWeights = nullptr;
399
400 TensorInfo optInputToInputWeights;
401 TensorInfo optRecurrentToInputWeights;
402 TensorInfo optCellToInputWeights;
403 TensorInfo optInputGateBias;
404 TensorInfo optProjectionWeights;
405 TensorInfo optProjectionBias;
406 TensorInfo optCellToForgetWeights;
407 TensorInfo optCellToOutputWeights;
408
409 if(!descriptor.m_CifgEnabled)
410 {
411 optInputToInputWeights =
412 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
413 inputToInputWeights = &optInputToInputWeights;
414
415 optRecurrentToInputWeights =
416 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
417 recurrentToInputWeights = &optRecurrentToInputWeights;
418 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
419 {
420 optCellToInputWeights =
421 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
422 cellToInputWeights = &optCellToInputWeights;
423 }
424 optInputGateBias =
425 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
426 inputGateBias = &optInputGateBias;
427 }
428
429 if(descriptor.m_ProjectionEnabled)
430 {
431 optProjectionWeights =
432 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
433 projectionWeights = &optProjectionWeights;
434 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
435 {
436 optProjectionBias =
437 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
438 projectionBias = &optProjectionBias;
439 }
440 }
441
442 if(descriptor.m_PeepholeEnabled)
443 {
444 optCellToForgetWeights =
445 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
446 cellToForgetWeights = &optCellToForgetWeights;
447 optCellToOutputWeights =
448 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
449 cellToOutputWeights = &optCellToOutputWeights;
450 }
451
David Beck33f0ae02018-10-18 15:13:56 +0100452 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100453 input,
454 outputStateIn,
455 cellStateIn,
456 scratchBuffer,
457 outputStateOut,
458 cellStateOut,
459 output,
460 descriptor,
461 inputToForgetWeights,
462 inputToCellWeights,
463 inputToOutputWeights,
464 recurrentToForgetWeights,
465 recurrentToCellWeights,
466 recurrentToOutputWeights,
467 forgetGateBias,
468 cellBias,
469 outputGateBias,
470 inputToInputWeights,
471 recurrentToInputWeights,
472 cellToInputWeights,
473 inputGateBias,
474 projectionWeights,
475 projectionBias,
476 cellToForgetWeights,
477 cellToOutputWeights,
David Beck33f0ae02018-10-18 15:13:56 +0100478 reason);
telsoa014fcda012018-03-09 14:13:49 +0000479 break;
480 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000481 case LayerType::Maximum:
482 {
483 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
484 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
485 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
486
487 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
488 OverrideDataType(input1, dataType),
489 OverrideDataType(output, dataType),
490 reason);
491 break;
492 }
narpra01b89b05f2019-01-16 09:53:09 +0000493 case LayerType::MemCopy:
494 {
495 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
496 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000497
narpra01b89b05f2019-01-16 09:53:09 +0000498 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
499 OverrideDataType(output, dataType),
500 reason);
501 break;
502 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100503 case LayerType::Merge:
504 {
505 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
506 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
507 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
508
509 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
510 OverrideDataType(input1, dataType),
511 OverrideDataType(output, dataType),
512 reason);
513 break;
514 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100515 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000516 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100517 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000518
telsoa01c577f2c2018-08-31 09:22:23 +0100519 // Get vector of all inputs.
520 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000521 {
telsoa01c577f2c2018-08-31 09:22:23 +0100522 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000523 };
telsoa01c577f2c2018-08-31 09:22:23 +0100524 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
525 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
526 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000527
telsoa01c577f2c2018-08-31 09:22:23 +0100528 auto getTensorInfoPtr = [](const TensorInfo& info)
529 {
530 return &info;
531 };
532 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
533 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
534 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000535
Nikhil Raj8599a412018-11-19 14:51:07 +0000536 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
537
Jim Flynne242f2d2019-05-22 14:24:13 +0100538 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
539
540
telsoa014fcda012018-03-09 14:13:49 +0000541 break;
542 }
543 case LayerType::Multiplication:
544 {
545 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
546 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100547 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100548 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100549 OverrideDataType(input0, dataType),
550 OverrideDataType(input1, dataType),
551 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100552 reason);
telsoa014fcda012018-03-09 14:13:49 +0000553 break;
554 }
555 case LayerType::Normalization:
556 {
557 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
558 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
559 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100560 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
561 OverrideDataType(output, dataType),
562 cLayer->GetParameters(),
563 reason);
telsoa014fcda012018-03-09 14:13:49 +0000564 break;
565 }
566 case LayerType::Output:
567 {
568 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100569 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000570 break;
571 }
572 case LayerType::Permute:
573 {
574 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
575 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
576 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100577 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
578 OverrideDataType(output, dataType),
579 cLayer->GetParameters(),
580 reason);
telsoa014fcda012018-03-09 14:13:49 +0000581 break;
582 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100583 case LayerType::Pad:
584 {
585 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
586 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
587 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100588 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100589 OverrideDataType(input, dataType),
590 OverrideDataType(output, dataType),
591 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100592 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100593 break;
594 }
telsoa014fcda012018-03-09 14:13:49 +0000595 case LayerType::Pooling2d:
596 {
597 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
598 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
599 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100600 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
601 OverrideDataType(output, dataType),
602 cLayer->GetParameters(),
603 reason);
telsoa014fcda012018-03-09 14:13:49 +0000604 break;
605 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000606 case LayerType::PreCompiled:
607 {
608 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
609 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
610 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
611 cLayer->GetParameters(),
612 reason);
613 break;
614 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000615 case LayerType::Quantize:
616 {
617 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
618 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
619 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
620 break;
621 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100622 case LayerType::Division:
623 {
624 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
625 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
626 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100627 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100628 OverrideDataType(input0, dataType),
629 OverrideDataType(input1, dataType),
630 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100631 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100632 break;
633 }
telsoa014fcda012018-03-09 14:13:49 +0000634 case LayerType::Reshape:
635 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000636 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000637 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000638 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
639 cLayer->GetParameters(),
640 reason);
telsoa014fcda012018-03-09 14:13:49 +0000641 break;
642 }
643 case LayerType::ResizeBilinear:
644 {
645 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganc625f002018-12-17 11:32:16 +0000646 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
647 result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType),
648 OverrideDataType(output, dataType),
649 reason);
telsoa014fcda012018-03-09 14:13:49 +0000650 break;
651 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000652 case LayerType::Rsqrt:
653 {
654 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
655 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
656 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
657 OverrideDataType(output, dataType),
658 reason);
659 break;
660 }
telsoa014fcda012018-03-09 14:13:49 +0000661 case LayerType::Softmax:
662 {
663 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
664 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100665 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100666 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
667 OverrideDataType(output, dataType),
668 cLayer->GetParameters(),
669 reason);
telsoa014fcda012018-03-09 14:13:49 +0000670 break;
671 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000672 case LayerType::SpaceToBatchNd:
673 {
674 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
675 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
676 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
677 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
678 OverrideDataType(output, dataType),
679 cLayer->GetParameters(),
680 reason);
681 break;
682 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100683 case LayerType::SpaceToDepth:
684 {
685 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
686
687 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
688 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
689
690 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
691 OverrideDataType(output, dataType),
692 cLayer->GetParameters(),
693 reason);
694 break;
695 }
telsoa014fcda012018-03-09 14:13:49 +0000696 case LayerType::Splitter:
697 {
698 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
699 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100700
701 // Get vector of all outputs.
702 auto getTensorInfo = [&dataType](const OutputSlot& slot)
703 {
704 return OverrideDataType(slot.GetTensorInfo(), dataType);
705 };
706 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
707 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
708 std::vector<TensorInfo> outputs(beginI, endI);
709
710 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
711
David Beck33f0ae02018-10-18 15:13:56 +0100712 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100713 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100714 cLayer->GetParameters(),
715 reason);
telsoa014fcda012018-03-09 14:13:49 +0000716 break;
717 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000718 case LayerType::StridedSlice:
719 {
720 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
721 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
722 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
723 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
724 OverrideDataType(output, dataType),
725 cLayer->GetParameters(),
726 reason);
727 break;
728 }
David Beckc2044fe2018-09-05 15:00:38 +0100729 case LayerType::Subtraction:
730 {
731 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
732 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
733 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100734 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100735 OverrideDataType(input0, dataType),
736 OverrideDataType(input1, dataType),
737 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100738 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100739 break;
740 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100741 case LayerType::Switch:
742 {
743 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
744 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
745 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
746 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
747 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
748 OverrideDataType(input1, dataType),
749 OverrideDataType(output0, dataType),
750 OverrideDataType(output1, dataType),
751 reason);
752 break;
753 }
narpra0132b90462018-09-13 11:07:48 +0100754 case LayerType::Mean:
755 {
756 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
757 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
758 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100759 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100760 OverrideDataType(input, dataType),
761 OverrideDataType(output, dataType),
762 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100763 reason);
narpra0132b90462018-09-13 11:07:48 +0100764 break;
765 }
kevmay0190539692018-11-29 08:40:19 +0000766 case LayerType::Minimum:
767 {
768 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
769 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
770 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
771 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
772 OverrideDataType(input1, dataType),
773 OverrideDataType(output, dataType),
774 reason);
775 break;
776 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000777 case LayerType::Greater:
778 {
779 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
780 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
781 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000782 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
783 OverrideDataType(input1, dataType),
784 OverrideDataType(output, DataType::Boolean),
785 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000786 break;
787 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +0100788 case LayerType::Prelu:
789 {
790 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
791 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
792 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
793 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
794 OverrideDataType(alpha, dataType),
795 OverrideDataType(output, dataType),
796 reason);
797 break;
798 }
telsoa014fcda012018-03-09 14:13:49 +0000799 default:
800 {
801 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100802 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000803 result = false;
804 break;
805 }
806 }
telsoa014fcda012018-03-09 14:13:49 +0000807 return result;
808}
809
David Beckdcb751f2018-10-03 11:42:42 +0100810bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100811 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100812 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000813{
David Beckdcb751f2018-10-03 11:42:42 +0100814 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100815 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000816}
817
// Default implementations: each Create* function below returns an empty
// workload pointer, indicating that this factory does not provide a workload
// for that layer type. Concrete backend factories override the ones they support.
819std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
820 const WorkloadInfo& info) const
821{
822 return std::unique_ptr<IWorkload>();
823}
824
825std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
826 const WorkloadInfo& info) const
827{
828 return std::unique_ptr<IWorkload>();
829}
830
831std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
832 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
833{
834 return std::unique_ptr<IWorkload>();
835}
836
837std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
838 const WorkloadInfo& Info) const
839{
840 return std::unique_ptr<IWorkload>();
841}
842
Jim Flynne242f2d2019-05-22 14:24:13 +0100843std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +0100844 const WorkloadInfo& info) const
845{
846 return std::unique_ptr<IWorkload>();
847}
848
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000849std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
850 const WorkloadInfo& info) const
851{
852 return std::unique_ptr<IWorkload>();
853}
854
855std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
856 const WorkloadInfo& info) const
857{
858 return std::unique_ptr<IWorkload>();
859}
860
861std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
862 const WorkloadInfo& info) const
863{
864 return std::unique_ptr<IWorkload>();
865}
866
867std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
868 const WorkloadInfo& info) const
869{
870 return std::unique_ptr<IWorkload>();
871}
872
873std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
874 const WorkloadInfo& info) const
875{
876 return std::unique_ptr<IWorkload>();
877}
878
879std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
880 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
881{
882 return std::unique_ptr<IWorkload>();
883}
884
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000885std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
886 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
887{
888 return std::unique_ptr<IWorkload>();
889}
890
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000891std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
892 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
893{
894 return std::unique_ptr<IWorkload>();
895}
896
897std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
898 const WorkloadInfo& info) const
899{
900 return std::unique_ptr<IWorkload>();
901}
902
903std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
904 const WorkloadInfo& Info) const
905{
906 return std::unique_ptr<IWorkload>();
907}
908
909std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
910 const WorkloadInfo& info) const
911{
912 return std::unique_ptr<IWorkload>();
913}
914
915std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
916 const WorkloadInfo& info) const
917{
918 return std::unique_ptr<IWorkload>();
919}
920
921std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
922 const WorkloadInfo& info) const
923{
924 return std::unique_ptr<IWorkload>();
925}
926
927std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
928 const WorkloadInfo& info) const
929{
930 return std::unique_ptr<IWorkload>();
931}
932
933std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
934 const WorkloadInfo& info) const
935{
936 return std::unique_ptr<IWorkload>();
937}
938
939std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
940 const WorkloadInfo& info) const
941{
942 return std::unique_ptr<IWorkload>();
943}
944
945std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
946 const WorkloadInfo& info) const
947{
948 return std::unique_ptr<IWorkload>();
949}
950
951std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
952 const WorkloadInfo& info) const
953{
954 return std::unique_ptr<IWorkload>();
955}
956
957std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
958 const WorkloadInfo& Info) const
959{
960 return std::unique_ptr<IWorkload>();
961}
962
963std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
964 const WorkloadInfo& info) const
965{
966 return std::unique_ptr<IWorkload>();
967}
968
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100969std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
970 const WorkloadInfo& info) const
971{
972 return std::unique_ptr<IWorkload>();
973}
974
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000975std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
976 const WorkloadInfo& info) const
977{
978 return std::unique_ptr<IWorkload>();
979}
980
981std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
982 const WorkloadInfo& info) const
983{
984 return std::unique_ptr<IWorkload>();
985}
986
987std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
988 const WorkloadInfo& info) const
989{
990 return std::unique_ptr<IWorkload>();
991}
992
993std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
994 const WorkloadInfo& info) const
995{
996 return std::unique_ptr<IWorkload>();
997}
998
999std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1000 const WorkloadInfo& info) const
1001{
1002 return std::unique_ptr<IWorkload>();
1003}
1004
1005std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1006 const WorkloadInfo& Info) const
1007{
1008 return std::unique_ptr<IWorkload>();
1009}
1010
1011std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1012 const WorkloadInfo& info) const
1013{
1014 return std::unique_ptr<IWorkload>();
1015}
1016
1017std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1018 const WorkloadInfo& info) const
1019{
1020 return std::unique_ptr<IWorkload>();
1021}
1022
1023std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1024 const WorkloadInfo& info) const
1025{
1026 return std::unique_ptr<IWorkload>();
1027}
1028
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001029std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1030 const WorkloadInfo &info) const
1031{
1032 return std::unique_ptr<IWorkload>();
1033}
1034
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001035std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1036 const WorkloadInfo& Info) const
1037{
1038 return std::unique_ptr<IWorkload>();
1039}
1040
1041std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1042 const WorkloadInfo& info) const
1043{
1044 return std::unique_ptr<IWorkload>();
1045}
1046
1047std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1048 const WorkloadInfo& info) const
1049{
1050 return std::unique_ptr<IWorkload>();
1051}
1052
1053std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1054 const WorkloadInfo& info) const
1055{
1056 return std::unique_ptr<IWorkload>();
1057}
1058
1059std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1060 const WorkloadInfo& info) const
1061{
1062 return std::unique_ptr<IWorkload>();
1063}
1064
1065std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1066 const WorkloadInfo& info) const
1067{
1068 return std::unique_ptr<IWorkload>();
1069}
1070
1071std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1072 const WorkloadInfo& info) const
1073{
1074 return std::unique_ptr<IWorkload>();
1075}
1076
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001077std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1078 const WorkloadInfo& info) const
1079{
1080 return std::unique_ptr<IWorkload>();
1081}
1082
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001083std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1084 const WorkloadInfo& Info) const
1085{
1086 return std::unique_ptr<IWorkload>();
1087}
1088
1089std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1090 const WorkloadInfo& info) const
1091{
1092 return std::unique_ptr<IWorkload>();
1093}
1094
Sadik Armaganeff363d2019-04-05 15:25:46 +01001095std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1096 const WorkloadInfo& info) const
1097{
1098 return std::unique_ptr<IWorkload>();
1099}
1100
surmeh013537c2c2018-05-18 16:31:43 +01001101}