//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CpuTensorHandle.hpp"
#include "WorkloadFactory.hpp"

#include <Layer.hpp>
#include <LayersFwd.hpp>

#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <armnn/ILayerSupport.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
#include <sstream>

namespace armnn
{

namespace
{

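// Returns a copy of the given TensorInfo with its data type replaced by the
// supplied optional type, preserving shape and quantization parameters; if no
// type is supplied, the original info is returned unchanged.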
const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }

    return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
}

} // anonymous namespace

bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

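    // Dispatch on the layer type: each case collects the relevant tensor
    // infos (with the data type optionally overridden) and queries the
    // backend's ILayerSupport interface.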
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            cLayer->GetParameters(),
                                            reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            OverrideDataType(mean, dataType),
                                            OverrideDataType(var, dataType),
                                            OverrideDataType(beta, dataType),
                                            OverrideDataType(gamma, dataType),
                                            cLayer->GetParameters(),
                                            reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                            input,
                                            output,
                                            descriptor,
                                            OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                            biases,
                                            reason);
            break;
        }
        case LayerType::Debug:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                            input,
                                            output,
                                            descriptor,
                                            OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                            biases,
                                            reason);
            break;
        }
        case LayerType::Dequantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
                                                               OverrideDataType(output, DataType::Float32),
                                                               reason);
            break;
        }
        case LayerType::DetectionPostProcess:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
            const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
            result = layerSupportObject->IsDetectionPostProcessSupported(input0,
                                                                         input1,
                                                                         descriptor,
                                                                         reason);
            break;
        }
        case LayerType::Equal:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
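        // FullyConnected validation needs a bias TensorInfo even when bias is
        // disabled, so a static dummy info matching the input data type is
        // passed in that case.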
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo* biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy tensorinfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    case DataType::QuantisedSymm16:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                            *biasInfoPtr,
                                            descriptor,
                                            reason);
            break;
        }
        case LayerType::Gather:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
                                                           input1,
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            descriptor,
                                            reason);
            break;
        }
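        // The Lstm case gathers the mandatory weights and biases first, then
        // conditionally adds the CIFG, projection, peephole and
        // layer-normalization parameters according to the descriptor flags.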
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            LstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
            paramsInfo.m_InputToCellWeights = &inputToCellWeights;
            paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
            paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
            paramsInfo.m_ForgetGateBias = &forgetGateBias;
            paramsInfo.m_CellBias = &cellBias;
            paramsInfo.m_OutputGateBias = &outputGateBias;

            // Optional parameters
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;
            TensorInfo optInputLayerNormWeights;
            TensorInfo optForgetLayerNormWeights;
            TensorInfo optCellLayerNormWeights;
            TensorInfo optOutputLayerNormWeights;

            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                paramsInfo.m_InputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
            }

            if(descriptor.m_LayerNormEnabled)
            {
                if (!descriptor.m_CifgEnabled)
                {
                    optInputLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
                }

                optForgetLayerNormWeights = OverrideDataType(
                    cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;

                optCellLayerNormWeights = OverrideDataType(
                    cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;

                optOutputLayerNormWeights = OverrideDataType(
                    cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                            input,
                                            outputStateIn,
                                            cellStateIn,
                                            scratchBuffer,
                                            outputStateOut,
                                            cellStateOut,
                                            output,
                                            descriptor,
                                            paramsInfo,
                                            reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemImport:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              reason);
            break;
        }
        case LayerType::Merge:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
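        // Concat takes a variable number of inputs; the transform iterators
        // below first materialize the (possibly type-overridden) TensorInfos
        // and then collect pointers to them for the ILayerSupport query.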
        case LayerType::Concat:
        {
            auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            cLayer->GetParameters(),
                                            reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::PreCompiled:
        {
            auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        case LayerType::Quantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsQuantizeSupported(input, output, reason);
            break;
        }
        case LayerType::QuantizedLstm:
        {
            auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();

            // QuantizedLstm parameters
            QuantizedLstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToInputWeights =
                &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
            paramsInfo.m_InputToForgetWeights =
                &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights =
                &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights =
                &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToInputWeights =
                &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToForgetWeights =
                &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_InputGateBias =
                &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
            paramsInfo.m_ForgetGateBias =
                &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias =
                &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias =
                &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();

            result = layerSupportObject->IsQuantizedLstmSupported(input,
                                                                  previousCellStateIn,
                                                                  previousOutputIn,
                                                                  cellStateOut,
                                                                  output,
                                                                  paramsInfo,
                                                                  reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Resize:
        {
            auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::Rsqrt:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::SpaceToDepth:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();

            // Get vector of all outputs.
            auto getTensorInfo = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> outputs(beginI, endI);

            const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());

            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             outputPtrs,
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::Stack:
        {
            auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::Switch:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
            result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output0, dataType),
                                                           OverrideDataType(output1, dataType),
                                                           reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            cLayer->GetParameters(),
                                            reason);
            break;
        }
        case LayerType::Minimum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Greater:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, DataType::Boolean),
                                                            reason);
            break;
        }
        case LayerType::Prelu:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(alpha, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::TransposeConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);

            const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
                                          GetBiasTypeFromWeightsType(dataType));
            }

            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
            const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);

            result = layerSupportObject->IsTransposeConvolution2dSupported(input,
                                                                           output,
                                                                           descriptor,
                                                                           weights,
                                                                           biases,
                                                                           reason);

            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}

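// Convenience overload: resolves the backend from the layer itself and
// forwards to the BackendId overload above.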
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
    return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
}

// Default Implementations
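// Each default Create* implementation returns an empty unique_ptr (a null
// workload), meaning the workload is not implemented; backend-specific
// factories override the functions they actually support.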
std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
    const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
                                                                     const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
                                                                     const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
    const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
    const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
    const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
                                                                    const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
                                                        const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
                                                        const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
                                                       const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
                                                                const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
                                                                const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
    const TransposeConvolution2dQueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}

} // namespace armnn