blob: 3502c381e8cfebc2f45a2bcc472132e2a06ebde6 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
Derek Lambertia9cca6a2019-03-25 15:41:58 +00007#include "WorkloadFactory.hpp"
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Layer.hpp>
11#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010012
David Beckb4540be2018-09-24 13:18:27 +010013#include <armnn/Types.hpp>
14#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000015#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000019#include <backendsCommon/IBackendInternal.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/iterator/transform_iterator.hpp>
24
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000025#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000026#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000027
telsoa014fcda012018-03-09 14:13:49 +000028namespace armnn
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031namespace
32{
telsoa01c577f2c2018-08-31 09:22:23 +010033
David Beck29c75de2018-10-23 13:35:58 +010034const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
35{
36 if (!type)
37 {
38 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010039 }
40
David Beck29c75de2018-10-23 13:35:58 +010041 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
David Beck33f0ae02018-10-18 15:13:56 +010046bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010047 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010048 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010049 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000050{
David Beck33f0ae02018-10-18 15:13:56 +010051 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000052 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010053 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
54
David Beck111b5d92018-11-12 14:59:37 +000055 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
57 {
58 std::stringstream ss;
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
61
62 outReasonIfUnsupported = ss.str();
63 return false;
64 }
65
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010069
telsoa014fcda012018-03-09 14:13:49 +000070 switch(layer.GetType())
71 {
72 case LayerType::Activation:
73 {
74 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
75 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010076 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010077 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010078 OverrideDataType(input, dataType),
79 OverrideDataType(output, dataType),
80 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010081 reason);
telsoa014fcda012018-03-09 14:13:49 +000082 break;
83 }
84 case LayerType::Addition:
85 {
86 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
87 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
88 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010089 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010090 OverrideDataType(input0, dataType),
91 OverrideDataType(input1, dataType),
92 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010093 reason);
telsoa014fcda012018-03-09 14:13:49 +000094 break;
95 }
96 case LayerType::BatchNormalization:
97 {
98 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
99 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100100 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
101 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
102 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
103 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
104 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100105 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100106 OverrideDataType(input, dataType),
107 OverrideDataType(output, dataType),
108 OverrideDataType(mean, dataType),
109 OverrideDataType(var, dataType),
110 OverrideDataType(beta, dataType),
111 OverrideDataType(gamma, dataType),
112 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100113 reason);
telsoa014fcda012018-03-09 14:13:49 +0000114 break;
115 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000116 case LayerType::BatchToSpaceNd:
117 {
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
121
122 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
123 OverrideDataType(output, dataType),
124 cLayer->GetParameters(),
125 reason);
126 break;
127 }
telsoa014fcda012018-03-09 14:13:49 +0000128 case LayerType::Constant:
129 {
130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100131 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100132 break;
133 }
134 case LayerType::ConvertFp16ToFp32:
135 {
136 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100138 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100139 break;
140 }
141 case LayerType::ConvertFp32ToFp16:
142 {
143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100145 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000146 break;
147 }
148 case LayerType::Convolution2d:
149 {
150 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100151
152 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
153 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100154 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100155 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
156
arovir01a6824102018-08-28 17:40:45 +0100157 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100158
arovir01a6824102018-08-28 17:40:45 +0100159 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100160 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100161 if (descriptor.m_BiasEnabled)
162 {
David Beck5eec11d2018-10-04 15:43:17 +0100163 biases =
164 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100165 }
166
David Beck33f0ae02018-10-18 15:13:56 +0100167 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100168 input,
169 output,
170 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100171 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100172 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100173 reason);
telsoa014fcda012018-03-09 14:13:49 +0000174 break;
175 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000176 case LayerType::Debug:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180
181 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
182 OverrideDataType(output, dataType),
183 reason);
184 break;
185 }
telsoa014fcda012018-03-09 14:13:49 +0000186 case LayerType::DepthwiseConvolution2d:
187 {
188 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
190 dataType);
191 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
192 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
193
telsoa01c577f2c2018-08-31 09:22:23 +0100194 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100195
196 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100197 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100198 if (descriptor.m_BiasEnabled)
199 {
David Beck5eec11d2018-10-04 15:43:17 +0100200 biases =
201 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100202 }
telsoa01c577f2c2018-08-31 09:22:23 +0100203
David Beck33f0ae02018-10-18 15:13:56 +0100204 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100205 input,
206 output,
207 descriptor,
208 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100209 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100210 reason);
telsoa014fcda012018-03-09 14:13:49 +0000211 break;
212 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000213 case LayerType::Dequantize:
214 {
215 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
216 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
217
218 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
219 OverrideDataType(output, DataType::Float32),
220 reason);
221 break;
222 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000223 case LayerType::DetectionPostProcess:
224 {
225 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
226 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
227 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
228 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
229 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
230 input1,
231 descriptor,
232 reason);
233 break;
234 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000235 case LayerType::Equal:
236 {
237 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
238 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
240 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
241 OverrideDataType(input1, dataType),
242 OverrideDataType(output, dataType),
243 reason);
244 break;
245 }
telsoa014fcda012018-03-09 14:13:49 +0000246 case LayerType::FakeQuantization:
247 {
248 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
249 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100250 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
251 cLayer->GetParameters(),
252 reason);
telsoa014fcda012018-03-09 14:13:49 +0000253 break;
254 }
255 case LayerType::Floor:
256 {
257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
260 OverrideDataType(output, dataType),
261 reason);
telsoa014fcda012018-03-09 14:13:49 +0000262 break;
263 }
264 case LayerType::FullyConnected:
265 {
266 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
267 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100268 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
269 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
270
271 TensorInfo biasInfo;
272 const TensorInfo * biasInfoPtr = nullptr;
273 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
274 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
275 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
276
277 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
278 if (descriptor.m_BiasEnabled)
279 {
280 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
281 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
282 biasInfoPtr = &biasInfo;
283 }
284 else
285 {
286 // If biases are not enabled pass a dummy tensorinfo for the validation
287 switch(input.GetDataType())
288 {
289 case DataType::Float16:
290 {
291 biasInfoPtr = &dummyFloat16Bias;
292 break;
293 }
294 case DataType::Float32:
295 {
296 biasInfoPtr = &dummyFloat32Bias;
297 break;
298 }
299 case DataType::QuantisedAsymm8:
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100300 case DataType::QuantisedSymm16:
telsoa01c577f2c2018-08-31 09:22:23 +0100301 {
302 biasInfoPtr = &dummyQA8Bias;
303 break;
304 }
305 default:
306 {
307 BOOST_ASSERT_MSG(false, "Unexpected bias type");
308 }
309 }
310 }
311
David Beck33f0ae02018-10-18 15:13:56 +0100312 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100313 OverrideDataType(input, dataType),
314 OverrideDataType(output, dataType),
315 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
316 *biasInfoPtr,
317 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100318 reason);
telsoa014fcda012018-03-09 14:13:49 +0000319 break;
320 }
narpra01b89b05f2019-01-16 09:53:09 +0000321 case LayerType::Gather:
322 {
323 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
324 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
325 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
326 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100327 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000328 OverrideDataType(output, dataType),
329 reason);
330 break;
331 }
telsoa014fcda012018-03-09 14:13:49 +0000332 case LayerType::Input:
333 {
334 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100335 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000336 break;
337 }
338 case LayerType::L2Normalization:
339 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100340 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
341 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
342
telsoa014fcda012018-03-09 14:13:49 +0000343 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100344 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100345
David Beck33f0ae02018-10-18 15:13:56 +0100346 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100347 OverrideDataType(input, dataType),
348 OverrideDataType(output, dataType),
349 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100350 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100351 break;
352 }
353 case LayerType::Lstm:
354 {
355 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
356 const LstmDescriptor& descriptor = cLayer->GetParameters();
357
358 // All inputs.
359 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
360 dataType);
361 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
362 dataType);
363 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
364 dataType);
365 // All outputs
366 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
367 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
368 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
369 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
370
371 // Basic parameters
372 const TensorInfo& inputToForgetWeights
373 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
374 const TensorInfo& inputToCellWeights
375 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
376 const TensorInfo& inputToOutputWeights
377 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
378 const TensorInfo& recurrentToForgetWeights
379 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
380 const TensorInfo& recurrentToCellWeights
381 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
382 const TensorInfo& recurrentToOutputWeights
383 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
384 const TensorInfo& forgetGateBias
385 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
386 const TensorInfo& cellBias
387 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
388 const TensorInfo& outputGateBias
389 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
390
391 // Optional parameters
392 const TensorInfo* inputToInputWeights = nullptr;
393 const TensorInfo* recurrentToInputWeights = nullptr;
394 const TensorInfo* cellToInputWeights = nullptr;
395 const TensorInfo* inputGateBias = nullptr;
396 const TensorInfo* projectionWeights = nullptr;
397 const TensorInfo* projectionBias = nullptr;
398 const TensorInfo* cellToForgetWeights = nullptr;
399 const TensorInfo* cellToOutputWeights = nullptr;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100400 const TensorInfo* inputLayerNormWeights = nullptr;
401 const TensorInfo* forgetLayerNormWeights = nullptr;
402 const TensorInfo* cellLayerNormWeights = nullptr;
403 const TensorInfo* outputLayerNormWeights = nullptr;
telsoa01c577f2c2018-08-31 09:22:23 +0100404
405 TensorInfo optInputToInputWeights;
406 TensorInfo optRecurrentToInputWeights;
407 TensorInfo optCellToInputWeights;
408 TensorInfo optInputGateBias;
409 TensorInfo optProjectionWeights;
410 TensorInfo optProjectionBias;
411 TensorInfo optCellToForgetWeights;
412 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100413 TensorInfo optInputLayerNormWeights;
414 TensorInfo optForgetLayerNormWeights;
415 TensorInfo optCellLayerNormWeights;
416 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100417
418 if(!descriptor.m_CifgEnabled)
419 {
420 optInputToInputWeights =
421 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
422 inputToInputWeights = &optInputToInputWeights;
423
424 optRecurrentToInputWeights =
425 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
426 recurrentToInputWeights = &optRecurrentToInputWeights;
427 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
428 {
429 optCellToInputWeights =
430 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
431 cellToInputWeights = &optCellToInputWeights;
432 }
433 optInputGateBias =
434 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
435 inputGateBias = &optInputGateBias;
436 }
437
438 if(descriptor.m_ProjectionEnabled)
439 {
440 optProjectionWeights =
441 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
442 projectionWeights = &optProjectionWeights;
443 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
444 {
445 optProjectionBias =
446 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
447 projectionBias = &optProjectionBias;
448 }
449 }
450
451 if(descriptor.m_PeepholeEnabled)
452 {
453 optCellToForgetWeights =
454 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
455 cellToForgetWeights = &optCellToForgetWeights;
456 optCellToOutputWeights =
457 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
458 cellToOutputWeights = &optCellToOutputWeights;
459 }
460
Jan Eilers38e05bd2019-06-26 13:10:09 +0100461 if(descriptor.m_LayerNormEnabled)
462 {
463 optInputLayerNormWeights = OverrideDataType(
464 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
465 inputLayerNormWeights = &optInputLayerNormWeights;
466
467 optForgetLayerNormWeights = OverrideDataType(
468 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
469 forgetLayerNormWeights = &optForgetLayerNormWeights;
470
471 optCellLayerNormWeights = OverrideDataType(
472 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
473 cellLayerNormWeights = &optCellLayerNormWeights;
474
475 optOutputLayerNormWeights = OverrideDataType(
476 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
477 outputLayerNormWeights = &optOutputLayerNormWeights;
478 }
479
David Beck33f0ae02018-10-18 15:13:56 +0100480 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100481 input,
482 outputStateIn,
483 cellStateIn,
484 scratchBuffer,
485 outputStateOut,
486 cellStateOut,
487 output,
488 descriptor,
489 inputToForgetWeights,
490 inputToCellWeights,
491 inputToOutputWeights,
492 recurrentToForgetWeights,
493 recurrentToCellWeights,
494 recurrentToOutputWeights,
495 forgetGateBias,
496 cellBias,
497 outputGateBias,
498 inputToInputWeights,
499 recurrentToInputWeights,
500 cellToInputWeights,
501 inputGateBias,
502 projectionWeights,
503 projectionBias,
504 cellToForgetWeights,
505 cellToOutputWeights,
Jan Eilers38e05bd2019-06-26 13:10:09 +0100506 reason,
507 inputLayerNormWeights,
508 forgetLayerNormWeights,
509 cellLayerNormWeights,
510 outputLayerNormWeights);
telsoa014fcda012018-03-09 14:13:49 +0000511 break;
512 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000513 case LayerType::Maximum:
514 {
515 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
516 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
517 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
518
519 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
520 OverrideDataType(input1, dataType),
521 OverrideDataType(output, dataType),
522 reason);
523 break;
524 }
narpra01b89b05f2019-01-16 09:53:09 +0000525 case LayerType::MemCopy:
526 {
527 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
528 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000529
narpra01b89b05f2019-01-16 09:53:09 +0000530 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
531 OverrideDataType(output, dataType),
532 reason);
533 break;
534 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100535 case LayerType::Merge:
536 {
537 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
538 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
539 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
540
541 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
542 OverrideDataType(input1, dataType),
543 OverrideDataType(output, dataType),
544 reason);
545 break;
546 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100547 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000548 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100549 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000550
telsoa01c577f2c2018-08-31 09:22:23 +0100551 // Get vector of all inputs.
552 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000553 {
telsoa01c577f2c2018-08-31 09:22:23 +0100554 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000555 };
telsoa01c577f2c2018-08-31 09:22:23 +0100556 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
557 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
558 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000559
telsoa01c577f2c2018-08-31 09:22:23 +0100560 auto getTensorInfoPtr = [](const TensorInfo& info)
561 {
562 return &info;
563 };
564 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
565 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
566 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000567
Nikhil Raj8599a412018-11-19 14:51:07 +0000568 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
569
Jim Flynne242f2d2019-05-22 14:24:13 +0100570 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
571
572
telsoa014fcda012018-03-09 14:13:49 +0000573 break;
574 }
575 case LayerType::Multiplication:
576 {
577 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
578 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100579 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100580 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100581 OverrideDataType(input0, dataType),
582 OverrideDataType(input1, dataType),
583 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100584 reason);
telsoa014fcda012018-03-09 14:13:49 +0000585 break;
586 }
587 case LayerType::Normalization:
588 {
589 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
590 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
591 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100592 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
593 OverrideDataType(output, dataType),
594 cLayer->GetParameters(),
595 reason);
telsoa014fcda012018-03-09 14:13:49 +0000596 break;
597 }
598 case LayerType::Output:
599 {
600 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100601 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000602 break;
603 }
604 case LayerType::Permute:
605 {
606 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
607 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
608 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100609 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
610 OverrideDataType(output, dataType),
611 cLayer->GetParameters(),
612 reason);
telsoa014fcda012018-03-09 14:13:49 +0000613 break;
614 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100615 case LayerType::Pad:
616 {
617 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
618 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
619 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100620 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100621 OverrideDataType(input, dataType),
622 OverrideDataType(output, dataType),
623 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100624 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100625 break;
626 }
telsoa014fcda012018-03-09 14:13:49 +0000627 case LayerType::Pooling2d:
628 {
629 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
630 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
631 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100632 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
633 OverrideDataType(output, dataType),
634 cLayer->GetParameters(),
635 reason);
telsoa014fcda012018-03-09 14:13:49 +0000636 break;
637 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000638 case LayerType::PreCompiled:
639 {
640 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
641 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
642 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
643 cLayer->GetParameters(),
644 reason);
645 break;
646 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000647 case LayerType::Quantize:
648 {
649 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
650 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
651 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
652 break;
653 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100654 case LayerType::Division:
655 {
656 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
657 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
658 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100659 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100660 OverrideDataType(input0, dataType),
661 OverrideDataType(input1, dataType),
662 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100663 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100664 break;
665 }
telsoa014fcda012018-03-09 14:13:49 +0000666 case LayerType::Reshape:
667 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000668 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000669 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000670 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
671 cLayer->GetParameters(),
672 reason);
telsoa014fcda012018-03-09 14:13:49 +0000673 break;
674 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100675 case LayerType::Resize:
676 {
677 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100678 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100679 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
680 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
681 OverrideDataType(output, dataType),
682 cLayer->GetParameters(),
683 reason);
684 break;
685 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000686 case LayerType::Rsqrt:
687 {
688 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
689 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
690 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
691 OverrideDataType(output, dataType),
692 reason);
693 break;
694 }
telsoa014fcda012018-03-09 14:13:49 +0000695 case LayerType::Softmax:
696 {
697 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
698 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100699 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100700 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
701 OverrideDataType(output, dataType),
702 cLayer->GetParameters(),
703 reason);
telsoa014fcda012018-03-09 14:13:49 +0000704 break;
705 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000706 case LayerType::SpaceToBatchNd:
707 {
708 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
709 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
710 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
711 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
712 OverrideDataType(output, dataType),
713 cLayer->GetParameters(),
714 reason);
715 break;
716 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100717 case LayerType::SpaceToDepth:
718 {
719 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
720
721 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
722 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
723
724 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
725 OverrideDataType(output, dataType),
726 cLayer->GetParameters(),
727 reason);
728 break;
729 }
telsoa014fcda012018-03-09 14:13:49 +0000730 case LayerType::Splitter:
731 {
732 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
733 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100734
735 // Get vector of all outputs.
736 auto getTensorInfo = [&dataType](const OutputSlot& slot)
737 {
738 return OverrideDataType(slot.GetTensorInfo(), dataType);
739 };
740 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
741 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
742 std::vector<TensorInfo> outputs(beginI, endI);
743
744 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
745
David Beck33f0ae02018-10-18 15:13:56 +0100746 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100747 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100748 cLayer->GetParameters(),
749 reason);
telsoa014fcda012018-03-09 14:13:49 +0000750 break;
751 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000752 case LayerType::StridedSlice:
753 {
754 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
755 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
756 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
757 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
758 OverrideDataType(output, dataType),
759 cLayer->GetParameters(),
760 reason);
761 break;
762 }
David Beckc2044fe2018-09-05 15:00:38 +0100763 case LayerType::Subtraction:
764 {
765 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
766 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
767 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100768 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100769 OverrideDataType(input0, dataType),
770 OverrideDataType(input1, dataType),
771 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100772 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100773 break;
774 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100775 case LayerType::Switch:
776 {
777 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
778 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
779 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
780 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
781 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
782 OverrideDataType(input1, dataType),
783 OverrideDataType(output0, dataType),
784 OverrideDataType(output1, dataType),
785 reason);
786 break;
787 }
narpra0132b90462018-09-13 11:07:48 +0100788 case LayerType::Mean:
789 {
790 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
791 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
792 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100793 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100794 OverrideDataType(input, dataType),
795 OverrideDataType(output, dataType),
796 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100797 reason);
narpra0132b90462018-09-13 11:07:48 +0100798 break;
799 }
kevmay0190539692018-11-29 08:40:19 +0000800 case LayerType::Minimum:
801 {
802 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
803 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
804 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
805 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
806 OverrideDataType(input1, dataType),
807 OverrideDataType(output, dataType),
808 reason);
809 break;
810 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000811 case LayerType::Greater:
812 {
813 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
814 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
815 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000816 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
817 OverrideDataType(input1, dataType),
818 OverrideDataType(output, DataType::Boolean),
819 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000820 break;
821 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +0100822 case LayerType::Prelu:
823 {
824 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
825 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
826 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
827 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
828 OverrideDataType(alpha, dataType),
829 OverrideDataType(output, dataType),
830 reason);
831 break;
832 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +0100833 case LayerType::TransposeConvolution2d:
834 {
835 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
836
837 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
838 dataType);
839 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
840
841 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
842
843 Optional<TensorInfo> biases;
844 if (descriptor.m_BiasEnabled)
845 {
846 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
847 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
848 GetBiasTypeFromWeightsType(dataType));
849 }
850
851 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
852 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
853
854 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
855 output,
856 descriptor,
857 weights,
858 biases,
859 reason);
860
861 break;
862 }
telsoa014fcda012018-03-09 14:13:49 +0000863 default:
864 {
865 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100866 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000867 result = false;
868 break;
869 }
870 }
telsoa014fcda012018-03-09 14:13:49 +0000871 return result;
872}
873
David Beckdcb751f2018-10-03 11:42:42 +0100874bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100875 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100876 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000877{
David Beckdcb751f2018-10-03 11:42:42 +0100878 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100879 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000880}
881
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000882// Default Implementations
883std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
884 const WorkloadInfo& info) const
885{
886 return std::unique_ptr<IWorkload>();
887}
888
889std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
890 const WorkloadInfo& info) const
891{
892 return std::unique_ptr<IWorkload>();
893}
894
895std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
896 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
897{
898 return std::unique_ptr<IWorkload>();
899}
900
901std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
902 const WorkloadInfo& Info) const
903{
904 return std::unique_ptr<IWorkload>();
905}
906
Jim Flynne242f2d2019-05-22 14:24:13 +0100907std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +0100908 const WorkloadInfo& info) const
909{
910 return std::unique_ptr<IWorkload>();
911}
912
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000913std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
914 const WorkloadInfo& info) const
915{
916 return std::unique_ptr<IWorkload>();
917}
918
919std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
920 const WorkloadInfo& info) const
921{
922 return std::unique_ptr<IWorkload>();
923}
924
925std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
926 const WorkloadInfo& info) const
927{
928 return std::unique_ptr<IWorkload>();
929}
930
931std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
932 const WorkloadInfo& info) const
933{
934 return std::unique_ptr<IWorkload>();
935}
936
937std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
938 const WorkloadInfo& info) const
939{
940 return std::unique_ptr<IWorkload>();
941}
942
943std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
944 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
945{
946 return std::unique_ptr<IWorkload>();
947}
948
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000949std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
950 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
951{
952 return std::unique_ptr<IWorkload>();
953}
954
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000955std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
956 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
957{
958 return std::unique_ptr<IWorkload>();
959}
960
961std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
962 const WorkloadInfo& info) const
963{
964 return std::unique_ptr<IWorkload>();
965}
966
967std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
968 const WorkloadInfo& Info) const
969{
970 return std::unique_ptr<IWorkload>();
971}
972
973std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
974 const WorkloadInfo& info) const
975{
976 return std::unique_ptr<IWorkload>();
977}
978
979std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
980 const WorkloadInfo& info) const
981{
982 return std::unique_ptr<IWorkload>();
983}
984
985std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
986 const WorkloadInfo& info) const
987{
988 return std::unique_ptr<IWorkload>();
989}
990
991std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
992 const WorkloadInfo& info) const
993{
994 return std::unique_ptr<IWorkload>();
995}
996
997std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
998 const WorkloadInfo& info) const
999{
1000 return std::unique_ptr<IWorkload>();
1001}
1002
1003std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
1004 const WorkloadInfo& info) const
1005{
1006 return std::unique_ptr<IWorkload>();
1007}
1008
1009std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
1010 const WorkloadInfo& info) const
1011{
1012 return std::unique_ptr<IWorkload>();
1013}
1014
1015std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
1016 const WorkloadInfo& info) const
1017{
1018 return std::unique_ptr<IWorkload>();
1019}
1020
1021std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1022 const WorkloadInfo& Info) const
1023{
1024 return std::unique_ptr<IWorkload>();
1025}
1026
1027std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1028 const WorkloadInfo& info) const
1029{
1030 return std::unique_ptr<IWorkload>();
1031}
1032
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001033std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1034 const WorkloadInfo& info) const
1035{
1036 return std::unique_ptr<IWorkload>();
1037}
1038
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001039std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1040 const WorkloadInfo& info) const
1041{
1042 return std::unique_ptr<IWorkload>();
1043}
1044
1045std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1046 const WorkloadInfo& info) const
1047{
1048 return std::unique_ptr<IWorkload>();
1049}
1050
1051std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1052 const WorkloadInfo& info) const
1053{
1054 return std::unique_ptr<IWorkload>();
1055}
1056
1057std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1058 const WorkloadInfo& info) const
1059{
1060 return std::unique_ptr<IWorkload>();
1061}
1062
1063std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1064 const WorkloadInfo& info) const
1065{
1066 return std::unique_ptr<IWorkload>();
1067}
1068
1069std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1070 const WorkloadInfo& Info) const
1071{
1072 return std::unique_ptr<IWorkload>();
1073}
1074
1075std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1076 const WorkloadInfo& info) const
1077{
1078 return std::unique_ptr<IWorkload>();
1079}
1080
1081std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1082 const WorkloadInfo& info) const
1083{
1084 return std::unique_ptr<IWorkload>();
1085}
1086
1087std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1088 const WorkloadInfo& info) const
1089{
1090 return std::unique_ptr<IWorkload>();
1091}
1092
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001093std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1094 const WorkloadInfo &info) const
1095{
1096 return std::unique_ptr<IWorkload>();
1097}
1098
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001099std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1100 const WorkloadInfo& Info) const
1101{
1102 return std::unique_ptr<IWorkload>();
1103}
1104
1105std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1106 const WorkloadInfo& info) const
1107{
1108 return std::unique_ptr<IWorkload>();
1109}
1110
1111std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1112 const WorkloadInfo& info) const
1113{
1114 return std::unique_ptr<IWorkload>();
1115}
1116
Teresa Charlina9075df2019-06-27 15:41:57 +01001117std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1118 const WorkloadInfo& info) const
1119{
1120 return std::unique_ptr<IWorkload>();
1121}
1122
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001123std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1124 const WorkloadInfo& info) const
1125{
1126 return std::unique_ptr<IWorkload>();
1127}
1128
1129std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1130 const WorkloadInfo& info) const
1131{
1132 return std::unique_ptr<IWorkload>();
1133}
1134
1135std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1136 const WorkloadInfo& info) const
1137{
1138 return std::unique_ptr<IWorkload>();
1139}
1140
1141std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1142 const WorkloadInfo& info) const
1143{
1144 return std::unique_ptr<IWorkload>();
1145}
1146
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001147std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1148 const WorkloadInfo& info) const
1149{
1150 return std::unique_ptr<IWorkload>();
1151}
1152
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001153std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1154 const WorkloadInfo& Info) const
1155{
1156 return std::unique_ptr<IWorkload>();
1157}
1158
1159std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1160 const WorkloadInfo& info) const
1161{
1162 return std::unique_ptr<IWorkload>();
1163}
1164
Sadik Armaganeff363d2019-04-05 15:25:46 +01001165std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1166 const WorkloadInfo& info) const
1167{
1168 return std::unique_ptr<IWorkload>();
1169}
1170
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001171std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1172 const TransposeConvolution2dQueueDescriptor& descriptor,
1173 const WorkloadInfo& info) const
1174{
1175 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001176}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001177
} // namespace armnn