blob: 17bd98b34961d78b58958b195f564d6e247c122b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
Derek Lambertia9cca6a2019-03-25 15:41:58 +00007#include "WorkloadFactory.hpp"
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Layer.hpp>
11#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010012
David Beckb4540be2018-09-24 13:18:27 +010013#include <armnn/Types.hpp>
14#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000015#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000019#include <backendsCommon/IBackendInternal.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/iterator/transform_iterator.hpp>
24
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000025#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000026#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000027
telsoa014fcda012018-03-09 14:13:49 +000028namespace armnn
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031namespace
32{
telsoa01c577f2c2018-08-31 09:22:23 +010033
David Beck29c75de2018-10-23 13:35:58 +010034const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
35{
36 if (!type)
37 {
38 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010039 }
40
David Beck29c75de2018-10-23 13:35:58 +010041 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
David Beck33f0ae02018-10-18 15:13:56 +010046bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010047 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010048 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010049 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000050{
David Beck33f0ae02018-10-18 15:13:56 +010051 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000052 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010053 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
54
David Beck111b5d92018-11-12 14:59:37 +000055 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
57 {
58 std::stringstream ss;
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
61
62 outReasonIfUnsupported = ss.str();
63 return false;
64 }
65
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010069
telsoa014fcda012018-03-09 14:13:49 +000070 switch(layer.GetType())
71 {
Kevin May868eb142019-09-04 17:29:31 +010072 case LayerType::Abs:
73 {
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76 result = layerSupportObject->IsAbsSupported(OverrideDataType(input, dataType),
77 OverrideDataType(output, dataType),
78 reason);
79 break;
80 }
telsoa014fcda012018-03-09 14:13:49 +000081 case LayerType::Activation:
82 {
83 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
84 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010085 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010086 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010087 OverrideDataType(input, dataType),
88 OverrideDataType(output, dataType),
89 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010090 reason);
telsoa014fcda012018-03-09 14:13:49 +000091 break;
92 }
93 case LayerType::Addition:
94 {
95 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
96 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
97 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010098 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010099 OverrideDataType(input0, dataType),
100 OverrideDataType(input1, dataType),
101 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100102 reason);
telsoa014fcda012018-03-09 14:13:49 +0000103 break;
104 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 case LayerType::ArgMinMax:
106 {
107 auto cLayer = boost::polymorphic_downcast<const ArgMinMaxLayer*>(&layer);
108 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
109
110 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
111 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
112 result = layerSupportObject->IsArgMinMaxSupported(
113 OverrideDataType(input, dataType),
114 OverrideDataType(output, dataType),
115 descriptor,
116 reason);
117 break;
118 }
telsoa014fcda012018-03-09 14:13:49 +0000119 case LayerType::BatchNormalization:
120 {
121 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
122 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100123 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
124 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
125 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
126 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
127 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100128 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100129 OverrideDataType(input, dataType),
130 OverrideDataType(output, dataType),
131 OverrideDataType(mean, dataType),
132 OverrideDataType(var, dataType),
133 OverrideDataType(beta, dataType),
134 OverrideDataType(gamma, dataType),
135 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100136 reason);
telsoa014fcda012018-03-09 14:13:49 +0000137 break;
138 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000139 case LayerType::BatchToSpaceNd:
140 {
141 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
142 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
143 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
144
145 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
146 OverrideDataType(output, dataType),
147 cLayer->GetParameters(),
148 reason);
149 break;
150 }
telsoa014fcda012018-03-09 14:13:49 +0000151 case LayerType::Constant:
152 {
153 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100154 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100155 break;
156 }
157 case LayerType::ConvertFp16ToFp32:
158 {
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100161 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100162 break;
163 }
164 case LayerType::ConvertFp32ToFp16:
165 {
166 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
167 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100168 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000169 break;
170 }
171 case LayerType::Convolution2d:
172 {
173 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100174
175 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
176 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100177 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100178 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
179
arovir01a6824102018-08-28 17:40:45 +0100180 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100181
arovir01a6824102018-08-28 17:40:45 +0100182 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100183 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100184 if (descriptor.m_BiasEnabled)
185 {
David Beck5eec11d2018-10-04 15:43:17 +0100186 biases =
187 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100188 }
189
David Beck33f0ae02018-10-18 15:13:56 +0100190 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100191 input,
192 output,
193 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100194 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100195 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100196 reason);
telsoa014fcda012018-03-09 14:13:49 +0000197 break;
198 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000199 case LayerType::Debug:
200 {
201 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
202 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
203
204 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
205 OverrideDataType(output, dataType),
206 reason);
207 break;
208 }
telsoa014fcda012018-03-09 14:13:49 +0000209 case LayerType::DepthwiseConvolution2d:
210 {
211 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100212 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
213 dataType);
214 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
215 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
216
telsoa01c577f2c2018-08-31 09:22:23 +0100217 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100218
219 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100220 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100221 if (descriptor.m_BiasEnabled)
222 {
David Beck5eec11d2018-10-04 15:43:17 +0100223 biases =
224 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100225 }
telsoa01c577f2c2018-08-31 09:22:23 +0100226
David Beck33f0ae02018-10-18 15:13:56 +0100227 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100228 input,
229 output,
230 descriptor,
231 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100232 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100233 reason);
telsoa014fcda012018-03-09 14:13:49 +0000234 break;
235 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000236 case LayerType::Dequantize:
237 {
238 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
240
241 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
242 OverrideDataType(output, DataType::Float32),
243 reason);
244 break;
245 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000246 case LayerType::DetectionPostProcess:
247 {
248 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
249 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
250 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
251 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
252 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
253 input1,
254 descriptor,
255 reason);
256 break;
257 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000258 case LayerType::Equal:
259 {
260 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
261 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
262 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
263 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
264 OverrideDataType(input1, dataType),
265 OverrideDataType(output, dataType),
266 reason);
267 break;
268 }
telsoa014fcda012018-03-09 14:13:49 +0000269 case LayerType::FakeQuantization:
270 {
271 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
272 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100273 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
274 cLayer->GetParameters(),
275 reason);
telsoa014fcda012018-03-09 14:13:49 +0000276 break;
277 }
278 case LayerType::Floor:
279 {
280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100282 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
283 OverrideDataType(output, dataType),
284 reason);
telsoa014fcda012018-03-09 14:13:49 +0000285 break;
286 }
287 case LayerType::FullyConnected:
288 {
289 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
290 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100291 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
292 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
293
294 TensorInfo biasInfo;
295 const TensorInfo * biasInfoPtr = nullptr;
296 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
297 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
298 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
299
300 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
301 if (descriptor.m_BiasEnabled)
302 {
303 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
304 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
305 biasInfoPtr = &biasInfo;
306 }
307 else
308 {
309 // If biases are not enabled pass a dummy tensorinfo for the validation
310 switch(input.GetDataType())
311 {
312 case DataType::Float16:
313 {
314 biasInfoPtr = &dummyFloat16Bias;
315 break;
316 }
317 case DataType::Float32:
318 {
319 biasInfoPtr = &dummyFloat32Bias;
320 break;
321 }
322 case DataType::QuantisedAsymm8:
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100323 case DataType::QuantisedSymm16:
telsoa01c577f2c2018-08-31 09:22:23 +0100324 {
325 biasInfoPtr = &dummyQA8Bias;
326 break;
327 }
328 default:
329 {
330 BOOST_ASSERT_MSG(false, "Unexpected bias type");
331 }
332 }
333 }
334
David Beck33f0ae02018-10-18 15:13:56 +0100335 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100336 OverrideDataType(input, dataType),
337 OverrideDataType(output, dataType),
338 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
339 *biasInfoPtr,
340 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100341 reason);
telsoa014fcda012018-03-09 14:13:49 +0000342 break;
343 }
narpra01b89b05f2019-01-16 09:53:09 +0000344 case LayerType::Gather:
345 {
346 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
347 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
348 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
349 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100350 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000351 OverrideDataType(output, dataType),
352 reason);
353 break;
354 }
telsoa014fcda012018-03-09 14:13:49 +0000355 case LayerType::Input:
356 {
357 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100358 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000359 break;
360 }
361 case LayerType::L2Normalization:
362 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100363 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
364 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
365
telsoa014fcda012018-03-09 14:13:49 +0000366 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100367 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100368
David Beck33f0ae02018-10-18 15:13:56 +0100369 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100370 OverrideDataType(input, dataType),
371 OverrideDataType(output, dataType),
372 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100373 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100374 break;
375 }
376 case LayerType::Lstm:
377 {
378 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
379 const LstmDescriptor& descriptor = cLayer->GetParameters();
380
381 // All inputs.
382 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
383 dataType);
384 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
385 dataType);
386 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
387 dataType);
388 // All outputs
389 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
390 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
391 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
392 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
393
394 // Basic parameters
395 const TensorInfo& inputToForgetWeights
396 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
397 const TensorInfo& inputToCellWeights
398 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
399 const TensorInfo& inputToOutputWeights
400 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
401 const TensorInfo& recurrentToForgetWeights
402 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
403 const TensorInfo& recurrentToCellWeights
404 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
405 const TensorInfo& recurrentToOutputWeights
406 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
407 const TensorInfo& forgetGateBias
408 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
409 const TensorInfo& cellBias
410 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
411 const TensorInfo& outputGateBias
412 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
413
Jan Eilersd01a83c2019-07-03 18:20:40 +0100414 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100415
Jan Eilersd01a83c2019-07-03 18:20:40 +0100416 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
417 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
418 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
419 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
420 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
421 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
422 paramsInfo.m_ForgetGateBias = &forgetGateBias;
423 paramsInfo.m_CellBias = &cellBias;
424 paramsInfo.m_OutputGateBias = &outputGateBias;
425
426
427 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100428 TensorInfo optInputToInputWeights;
429 TensorInfo optRecurrentToInputWeights;
430 TensorInfo optCellToInputWeights;
431 TensorInfo optInputGateBias;
432 TensorInfo optProjectionWeights;
433 TensorInfo optProjectionBias;
434 TensorInfo optCellToForgetWeights;
435 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100436 TensorInfo optInputLayerNormWeights;
437 TensorInfo optForgetLayerNormWeights;
438 TensorInfo optCellLayerNormWeights;
439 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100440
441 if(!descriptor.m_CifgEnabled)
442 {
443 optInputToInputWeights =
444 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100445 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100446
447 optRecurrentToInputWeights =
448 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100449 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100450 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
451 {
452 optCellToInputWeights =
453 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100454 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100455 }
456 optInputGateBias =
457 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100458 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100459 }
460
461 if(descriptor.m_ProjectionEnabled)
462 {
463 optProjectionWeights =
464 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100465 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100466 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
467 {
468 optProjectionBias =
469 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100470 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100471 }
472 }
473
474 if(descriptor.m_PeepholeEnabled)
475 {
476 optCellToForgetWeights =
477 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100478 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100479 optCellToOutputWeights =
480 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100481 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100482 }
483
Jan Eilers38e05bd2019-06-26 13:10:09 +0100484 if(descriptor.m_LayerNormEnabled)
485 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100486 if (!descriptor.m_CifgEnabled)
487 {
488 optInputLayerNormWeights = OverrideDataType(
489 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
490 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
491 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100492
493 optForgetLayerNormWeights = OverrideDataType(
494 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100495 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100496
497 optCellLayerNormWeights = OverrideDataType(
498 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100499 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100500
501 optOutputLayerNormWeights = OverrideDataType(
502 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100503 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100504 }
505
David Beck33f0ae02018-10-18 15:13:56 +0100506 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100507 input,
508 outputStateIn,
509 cellStateIn,
510 scratchBuffer,
511 outputStateOut,
512 cellStateOut,
513 output,
514 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100515 paramsInfo,
516 reason);
telsoa014fcda012018-03-09 14:13:49 +0000517 break;
518 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000519 case LayerType::Maximum:
520 {
521 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
522 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
523 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
524
525 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
526 OverrideDataType(input1, dataType),
527 OverrideDataType(output, dataType),
528 reason);
529 break;
530 }
narpra01b89b05f2019-01-16 09:53:09 +0000531 case LayerType::MemCopy:
532 {
533 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
534 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000535
narpra01b89b05f2019-01-16 09:53:09 +0000536 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
537 OverrideDataType(output, dataType),
538 reason);
539 break;
540 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100541 case LayerType::MemImport:
542 {
543 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
544 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
545
546 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
547 OverrideDataType(output, dataType),
548 reason);
549 break;
550 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100551 case LayerType::Merge:
552 {
553 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
554 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
555 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
556
557 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
558 OverrideDataType(input1, dataType),
559 OverrideDataType(output, dataType),
560 reason);
561 break;
562 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100563 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000564 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100565 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000566
telsoa01c577f2c2018-08-31 09:22:23 +0100567 // Get vector of all inputs.
568 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000569 {
telsoa01c577f2c2018-08-31 09:22:23 +0100570 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000571 };
telsoa01c577f2c2018-08-31 09:22:23 +0100572 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
573 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
574 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000575
telsoa01c577f2c2018-08-31 09:22:23 +0100576 auto getTensorInfoPtr = [](const TensorInfo& info)
577 {
578 return &info;
579 };
580 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
581 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
582 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000583
Nikhil Raj8599a412018-11-19 14:51:07 +0000584 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
585
Jim Flynne242f2d2019-05-22 14:24:13 +0100586 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
587
588
telsoa014fcda012018-03-09 14:13:49 +0000589 break;
590 }
591 case LayerType::Multiplication:
592 {
593 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
594 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100595 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100596 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100597 OverrideDataType(input0, dataType),
598 OverrideDataType(input1, dataType),
599 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100600 reason);
telsoa014fcda012018-03-09 14:13:49 +0000601 break;
602 }
603 case LayerType::Normalization:
604 {
605 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
606 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
607 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100608 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
609 OverrideDataType(output, dataType),
610 cLayer->GetParameters(),
611 reason);
telsoa014fcda012018-03-09 14:13:49 +0000612 break;
613 }
614 case LayerType::Output:
615 {
616 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100617 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000618 break;
619 }
620 case LayerType::Permute:
621 {
622 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
623 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
624 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100625 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
626 OverrideDataType(output, dataType),
627 cLayer->GetParameters(),
628 reason);
telsoa014fcda012018-03-09 14:13:49 +0000629 break;
630 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100631 case LayerType::Pad:
632 {
633 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
634 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
635 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100636 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100637 OverrideDataType(input, dataType),
638 OverrideDataType(output, dataType),
639 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100640 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100641 break;
642 }
telsoa014fcda012018-03-09 14:13:49 +0000643 case LayerType::Pooling2d:
644 {
645 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
646 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
647 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100648 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
649 OverrideDataType(output, dataType),
650 cLayer->GetParameters(),
651 reason);
telsoa014fcda012018-03-09 14:13:49 +0000652 break;
653 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000654 case LayerType::PreCompiled:
655 {
656 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
657 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
658 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
659 cLayer->GetParameters(),
660 reason);
661 break;
662 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000663 case LayerType::Quantize:
664 {
665 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
666 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
667 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
668 break;
669 }
James Conroyee18dc82019-07-17 11:27:46 +0100670 case LayerType::QuantizedLstm:
671 {
672 auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
673
674 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100675 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
676 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
677 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100678
679 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100680 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
681 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100682
683 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100684 QuantizedLstmInputParamsInfo paramsInfo;
685
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100686 paramsInfo.m_InputToInputWeights =
687 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
688 paramsInfo.m_InputToForgetWeights =
689 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
690 paramsInfo.m_InputToCellWeights =
691 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
692 paramsInfo.m_InputToOutputWeights =
693 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100694
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100695 paramsInfo.m_RecurrentToInputWeights =
696 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
697 paramsInfo.m_RecurrentToForgetWeights =
698 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
699 paramsInfo.m_RecurrentToCellWeights =
700 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
701 paramsInfo.m_RecurrentToOutputWeights =
702 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100703
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100704 paramsInfo.m_InputGateBias =
705 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
706 paramsInfo.m_ForgetGateBias =
707 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
708 paramsInfo.m_CellBias =
709 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
710 paramsInfo.m_OutputGateBias =
711 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100712
713 result = layerSupportObject->IsQuantizedLstmSupported(input,
714 previousCellStateIn,
715 previousOutputIn,
716 cellStateOut,
717 output,
718 paramsInfo,
719 reason);
720 break;
721 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100722 case LayerType::Division:
723 {
724 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
725 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
726 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100727 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100728 OverrideDataType(input0, dataType),
729 OverrideDataType(input1, dataType),
730 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100731 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100732 break;
733 }
telsoa014fcda012018-03-09 14:13:49 +0000734 case LayerType::Reshape:
735 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000736 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000737 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000738 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
739 cLayer->GetParameters(),
740 reason);
telsoa014fcda012018-03-09 14:13:49 +0000741 break;
742 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100743 case LayerType::Resize:
744 {
745 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100746 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100747 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
748 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
749 OverrideDataType(output, dataType),
750 cLayer->GetParameters(),
751 reason);
752 break;
753 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000754 case LayerType::Rsqrt:
755 {
756 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
757 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
758 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
759 OverrideDataType(output, dataType),
760 reason);
761 break;
762 }
telsoa014fcda012018-03-09 14:13:49 +0000763 case LayerType::Softmax:
764 {
765 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
766 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100767 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100768 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
769 OverrideDataType(output, dataType),
770 cLayer->GetParameters(),
771 reason);
telsoa014fcda012018-03-09 14:13:49 +0000772 break;
773 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000774 case LayerType::SpaceToBatchNd:
775 {
776 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
777 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
778 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
779 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
780 OverrideDataType(output, dataType),
781 cLayer->GetParameters(),
782 reason);
783 break;
784 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100785 case LayerType::SpaceToDepth:
786 {
787 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
788
789 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
790 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
791
792 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
793 OverrideDataType(output, dataType),
794 cLayer->GetParameters(),
795 reason);
796 break;
797 }
telsoa014fcda012018-03-09 14:13:49 +0000798 case LayerType::Splitter:
799 {
800 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
801 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100802
803 // Get vector of all outputs.
804 auto getTensorInfo = [&dataType](const OutputSlot& slot)
805 {
806 return OverrideDataType(slot.GetTensorInfo(), dataType);
807 };
808 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
809 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
810 std::vector<TensorInfo> outputs(beginI, endI);
811
812 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
813
David Beck33f0ae02018-10-18 15:13:56 +0100814 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100815 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100816 cLayer->GetParameters(),
817 reason);
telsoa014fcda012018-03-09 14:13:49 +0000818 break;
819 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100820 case LayerType::Stack:
821 {
822 auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);
823
824 // Get vector of all inputs.
825 auto getTensorInfo = [&dataType](const InputSlot& slot)
826 {
827 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
828 };
829 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
830 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
831 std::vector<TensorInfo> inputs(beginI, endI);
832
833 auto getTensorInfoPtr = [](const TensorInfo& info)
834 {
835 return &info;
836 };
837 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
838 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
839 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
840
841 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
842
843 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
844
845 break;
846 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000847 case LayerType::StridedSlice:
848 {
849 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
850 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
851 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
852 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
853 OverrideDataType(output, dataType),
854 cLayer->GetParameters(),
855 reason);
856 break;
857 }
David Beckc2044fe2018-09-05 15:00:38 +0100858 case LayerType::Subtraction:
859 {
860 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
861 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
862 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100863 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100864 OverrideDataType(input0, dataType),
865 OverrideDataType(input1, dataType),
866 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100867 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100868 break;
869 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100870 case LayerType::Switch:
871 {
872 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
873 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
874 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
875 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
876 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
877 OverrideDataType(input1, dataType),
878 OverrideDataType(output0, dataType),
879 OverrideDataType(output1, dataType),
880 reason);
881 break;
882 }
narpra0132b90462018-09-13 11:07:48 +0100883 case LayerType::Mean:
884 {
885 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
886 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
887 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100888 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100889 OverrideDataType(input, dataType),
890 OverrideDataType(output, dataType),
891 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100892 reason);
narpra0132b90462018-09-13 11:07:48 +0100893 break;
894 }
kevmay0190539692018-11-29 08:40:19 +0000895 case LayerType::Minimum:
896 {
897 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
898 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
899 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
900 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
901 OverrideDataType(input1, dataType),
902 OverrideDataType(output, dataType),
903 reason);
904 break;
905 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000906 case LayerType::Greater:
907 {
908 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
909 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
910 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000911 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
912 OverrideDataType(input1, dataType),
913 OverrideDataType(output, DataType::Boolean),
914 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000915 break;
916 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +0100917 case LayerType::Prelu:
918 {
919 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
920 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
921 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
922 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
923 OverrideDataType(alpha, dataType),
924 OverrideDataType(output, dataType),
925 reason);
926 break;
927 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +0100928 case LayerType::TransposeConvolution2d:
929 {
930 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
931
932 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
933 dataType);
934 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
935
936 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
937
938 Optional<TensorInfo> biases;
939 if (descriptor.m_BiasEnabled)
940 {
941 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
942 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
943 GetBiasTypeFromWeightsType(dataType));
944 }
945
946 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
947 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
948
949 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
950 output,
951 descriptor,
952 weights,
953 biases,
954 reason);
955
956 break;
957 }
telsoa014fcda012018-03-09 14:13:49 +0000958 default:
959 {
960 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100961 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000962 result = false;
963 break;
964 }
965 }
telsoa014fcda012018-03-09 14:13:49 +0000966 return result;
967}
968
David Beckdcb751f2018-10-03 11:42:42 +0100969bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100970 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100971 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000972{
David Beckdcb751f2018-10-03 11:42:42 +0100973 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100974 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000975}
976
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000977// Default Implementations
Kevin May868eb142019-09-04 17:29:31 +0100978std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
979 const WorkloadInfo& info) const
980{
981 return std::unique_ptr<IWorkload>();
982}
983
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000984std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
985 const WorkloadInfo& info) const
986{
987 return std::unique_ptr<IWorkload>();
988}
989
990std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
991 const WorkloadInfo& info) const
992{
993 return std::unique_ptr<IWorkload>();
994}
995
Nikhil Rajee391d52019-09-05 17:50:44 +0100996std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
997 const WorkloadInfo& info) const
998{
999 return std::unique_ptr<IWorkload>();
1000}
1001
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001002std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
1003 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
1004{
1005 return std::unique_ptr<IWorkload>();
1006}
1007
1008std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
1009 const WorkloadInfo& Info) const
1010{
1011 return std::unique_ptr<IWorkload>();
1012}
1013
Jim Flynne242f2d2019-05-22 14:24:13 +01001014std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +01001015 const WorkloadInfo& info) const
1016{
1017 return std::unique_ptr<IWorkload>();
1018}
1019
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001020std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
1021 const WorkloadInfo& info) const
1022{
1023 return std::unique_ptr<IWorkload>();
1024}
1025
1026std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
1027 const WorkloadInfo& info) const
1028{
1029 return std::unique_ptr<IWorkload>();
1030}
1031
1032std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
1033 const WorkloadInfo& info) const
1034{
1035 return std::unique_ptr<IWorkload>();
1036}
1037
1038std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
1039 const WorkloadInfo& info) const
1040{
1041 return std::unique_ptr<IWorkload>();
1042}
1043
1044std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
1045 const WorkloadInfo& info) const
1046{
1047 return std::unique_ptr<IWorkload>();
1048}
1049
1050std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
1051 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
1052{
1053 return std::unique_ptr<IWorkload>();
1054}
1055
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001056std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
1057 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
1058{
1059 return std::unique_ptr<IWorkload>();
1060}
1061
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001062std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
1063 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
1064{
1065 return std::unique_ptr<IWorkload>();
1066}
1067
1068std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
1069 const WorkloadInfo& info) const
1070{
1071 return std::unique_ptr<IWorkload>();
1072}
1073
1074std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
1075 const WorkloadInfo& Info) const
1076{
1077 return std::unique_ptr<IWorkload>();
1078}
1079
1080std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
1081 const WorkloadInfo& info) const
1082{
1083 return std::unique_ptr<IWorkload>();
1084}
1085
1086std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
1087 const WorkloadInfo& info) const
1088{
1089 return std::unique_ptr<IWorkload>();
1090}
1091
1092std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
1093 const WorkloadInfo& info) const
1094{
1095 return std::unique_ptr<IWorkload>();
1096}
1097
1098std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
1099 const WorkloadInfo& info) const
1100{
1101 return std::unique_ptr<IWorkload>();
1102}
1103
1104std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
1105 const WorkloadInfo& info) const
1106{
1107 return std::unique_ptr<IWorkload>();
1108}
1109
1110std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
1111 const WorkloadInfo& info) const
1112{
1113 return std::unique_ptr<IWorkload>();
1114}
1115
1116std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
1117 const WorkloadInfo& info) const
1118{
1119 return std::unique_ptr<IWorkload>();
1120}
1121
1122std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
1123 const WorkloadInfo& info) const
1124{
1125 return std::unique_ptr<IWorkload>();
1126}
1127
1128std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1129 const WorkloadInfo& Info) const
1130{
1131 return std::unique_ptr<IWorkload>();
1132}
1133
1134std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1135 const WorkloadInfo& info) const
1136{
1137 return std::unique_ptr<IWorkload>();
1138}
1139
Derek Lambertif674aa02019-08-01 15:56:25 +01001140std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
1141 const WorkloadInfo& info) const
1142{
1143 return std::unique_ptr<IWorkload>();
1144}
1145
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001146std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1147 const WorkloadInfo& info) const
1148{
1149 return std::unique_ptr<IWorkload>();
1150}
1151
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001152std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1153 const WorkloadInfo& info) const
1154{
1155 return std::unique_ptr<IWorkload>();
1156}
1157
1158std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1159 const WorkloadInfo& info) const
1160{
1161 return std::unique_ptr<IWorkload>();
1162}
1163
1164std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1165 const WorkloadInfo& info) const
1166{
1167 return std::unique_ptr<IWorkload>();
1168}
1169
1170std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1171 const WorkloadInfo& info) const
1172{
1173 return std::unique_ptr<IWorkload>();
1174}
1175
1176std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1177 const WorkloadInfo& info) const
1178{
1179 return std::unique_ptr<IWorkload>();
1180}
1181
1182std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1183 const WorkloadInfo& Info) const
1184{
1185 return std::unique_ptr<IWorkload>();
1186}
1187
1188std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1189 const WorkloadInfo& info) const
1190{
1191 return std::unique_ptr<IWorkload>();
1192}
1193
1194std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1195 const WorkloadInfo& info) const
1196{
1197 return std::unique_ptr<IWorkload>();
1198}
1199
1200std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1201 const WorkloadInfo& info) const
1202{
1203 return std::unique_ptr<IWorkload>();
1204}
1205
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001206std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1207 const WorkloadInfo &info) const
1208{
1209 return std::unique_ptr<IWorkload>();
1210}
1211
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001212std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1213 const WorkloadInfo& Info) const
1214{
1215 return std::unique_ptr<IWorkload>();
1216}
1217
James Conroyee18dc82019-07-17 11:27:46 +01001218std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
1219 const WorkloadInfo& info) const
1220{
1221 return std::unique_ptr<IWorkload>();
1222}
1223
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001224std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1225 const WorkloadInfo& info) const
1226{
1227 return std::unique_ptr<IWorkload>();
1228}
1229
1230std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1231 const WorkloadInfo& info) const
1232{
1233 return std::unique_ptr<IWorkload>();
1234}
1235
Teresa Charlina9075df2019-06-27 15:41:57 +01001236std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1237 const WorkloadInfo& info) const
1238{
1239 return std::unique_ptr<IWorkload>();
1240}
1241
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001242std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1243 const WorkloadInfo& info) const
1244{
1245 return std::unique_ptr<IWorkload>();
1246}
1247
1248std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1249 const WorkloadInfo& info) const
1250{
1251 return std::unique_ptr<IWorkload>();
1252}
1253
1254std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1255 const WorkloadInfo& info) const
1256{
1257 return std::unique_ptr<IWorkload>();
1258}
1259
1260std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1261 const WorkloadInfo& info) const
1262{
1263 return std::unique_ptr<IWorkload>();
1264}
1265
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001266std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1267 const WorkloadInfo& info) const
1268{
1269 return std::unique_ptr<IWorkload>();
1270}
1271
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001272std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
1273 const WorkloadInfo& info) const
1274{
1275 return std::unique_ptr<IWorkload>();
1276}
1277
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001278std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1279 const WorkloadInfo& Info) const
1280{
1281 return std::unique_ptr<IWorkload>();
1282}
1283
1284std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1285 const WorkloadInfo& info) const
1286{
1287 return std::unique_ptr<IWorkload>();
1288}
1289
Sadik Armaganeff363d2019-04-05 15:25:46 +01001290std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1291 const WorkloadInfo& info) const
1292{
1293 return std::unique_ptr<IWorkload>();
1294}
1295
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001296std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1297 const TransposeConvolution2dQueueDescriptor& descriptor,
1298 const WorkloadInfo& info) const
1299{
1300 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001301}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001302
} // namespace armnn