//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CpuTensorHandle.hpp"
#include "WorkloadFactory.hpp"

#include <Layer.hpp>
#include <LayersFwd.hpp>

#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <armnn/ILayerSupport.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
#include <sstream>

namespace armnn
{

namespace
{

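// Returns a copy of 'info' with its data type replaced by 'type' when a type is
// given; the shape and quantization parameters are preserved unchanged.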
const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }

    return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
}

} // anonymous namespace

bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

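    // Resolve the backend through the registry; an unregistered backend id cannot
    // support any layer, so fail early with a descriptive reason.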
    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

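    // Dispatch on the layer type and forward the (optionally overridden) tensor
    // infos and layer parameters to the backend's ILayerSupport implementation.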
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                OverrideDataType(mean, dataType),
                OverrideDataType(var, dataType),
                OverrideDataType(beta, dataType),
                OverrideDataType(gamma, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

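            // The bias type depends on the weights type: GetBiasTypeFromWeightsType maps
            // float weights to the same float type and quantized 8-bit weights to Signed32.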
            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                input,
                output,
                descriptor,
                OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                biases,
                reason);
            break;
        }
        case LayerType::Debug:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                input,
                output,
                descriptor,
                OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                biases,
                reason);
            break;
        }
        case LayerType::Dequantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

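            // The output is validated as Float32: dequantization always produces a
            // float tensor, regardless of the requested dataType override.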
            result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
                                                               OverrideDataType(output, DataType::Float32),
                                                               reason);
            break;
        }
        case LayerType::DetectionPostProcess:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
            const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
            result = layerSupportObject->IsDetectionPostProcessSupported(input0,
                                                                         input1,
                                                                         descriptor,
                                                                         reason);
            break;
        }
        case LayerType::Equal:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

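            // IsFullyConnectedSupported always takes a bias TensorInfo, so static
            // dummies of each possible bias type cover the bias-disabled case below.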
            TensorInfo biasInfo;
            const TensorInfo * biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy tensorinfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                *biasInfoPtr,
                descriptor,
                reason);
            break;
        }
        case LayerType::Gather:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                descriptor,
                reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

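            // CIFG (coupled input-forget gate) omits the input gate entirely, so the
            // input gate parameters are only required when CIFG is disabled.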
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

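            // Peephole connections feed the cell state directly into the forget and
            // output gates, contributing two extra weight tensors when enabled.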
            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                input,
                outputStateIn,
                cellStateIn,
                scratchBuffer,
                outputStateOut,
                cellStateOut,
                output,
                descriptor,
                inputToForgetWeights,
                inputToCellWeights,
                inputToOutputWeights,
                recurrentToForgetWeights,
                recurrentToCellWeights,
                recurrentToOutputWeights,
                forgetGateBias,
                cellBias,
                outputGateBias,
                inputToInputWeights,
                recurrentToInputWeights,
                cellToInputWeights,
                inputGateBias,
                projectionWeights,
                projectionBias,
                cellToForgetWeights,
                cellToOutputWeights,
                reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Merge:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Concat:
        {
            auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

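            // IsConcatSupported takes the inputs as raw pointers, so build a parallel
            // vector of pointers into the TensorInfo vector above.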
            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::PreCompiled:
        {
            auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        case LayerType::Quantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsQuantizeSupported(input, output, reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   reason);
            break;
        }
        case LayerType::Rsqrt:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::SpaceToDepth:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();

            // Get vector of all outputs.
            auto getTensorInfo = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> outputs(beginI, endI);

            const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());

            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             outputPtrs,
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::Switch:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
            result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output0, dataType),
                                                           OverrideDataType(output1, dataType),
                                                           reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
        case LayerType::Minimum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Greater:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
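            // Comparison layers always produce Boolean outputs, so the output type is
            // forced to DataType::Boolean rather than the requested override.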
            result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, DataType::Boolean),
                                                            reason);
            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}

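// Convenience overload: resolves the backend from the layer itself instead of
// taking an explicit BackendId.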
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
    return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
}
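// Illustrative usage sketch (assumptions, not part of this file): 'layer' is a
// hypothetical IConnectableLayer* and "CpuRef" is assumed to be a registered
// backend id.
//
//     std::string reason;
//     if (!IWorkloadFactory::IsLayerSupported(BackendId("CpuRef"), *layer,
//                                             DataType::Float32, reason))
//     {
//         std::cerr << "Layer not supported: " << reason << std::endl;
//     }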

// Default Implementations
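// Each Create* function returns an empty unique_ptr by default, signalling that
// this factory does not implement the workload; concrete backend factories
// override only the creators they actually support.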
std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
    const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
                                                                     const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
                                                                     const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
    const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
    const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
    const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
                                                                    const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
                                                        const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
                                                        const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
                                                       const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
                                                                const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
                                                                const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

} // namespace armnn