blob: 1c23e1774b50a2026eed7e4702aa7989a0a6abc1 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
Derek Lambertia9cca6a2019-03-25 15:41:58 +00007#include "WorkloadFactory.hpp"
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Layer.hpp>
11#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010012
David Beckb4540be2018-09-24 13:18:27 +010013#include <armnn/Types.hpp>
14#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000015#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000019#include <backendsCommon/IBackendInternal.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/iterator/transform_iterator.hpp>
24
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000025#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000026#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000027
telsoa014fcda012018-03-09 14:13:49 +000028namespace armnn
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031namespace
32{
telsoa01c577f2c2018-08-31 09:22:23 +010033
David Beck29c75de2018-10-23 13:35:58 +010034const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
35{
36 if (!type)
37 {
38 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010039 }
40
David Beck29c75de2018-10-23 13:35:58 +010041 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
David Beck33f0ae02018-10-18 15:13:56 +010046bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010047 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010048 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010049 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000050{
David Beck33f0ae02018-10-18 15:13:56 +010051 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000052 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010053 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
54
David Beck111b5d92018-11-12 14:59:37 +000055 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
57 {
58 std::stringstream ss;
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
61
62 outReasonIfUnsupported = ss.str();
63 return false;
64 }
65
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010069
telsoa014fcda012018-03-09 14:13:49 +000070 switch(layer.GetType())
71 {
72 case LayerType::Activation:
73 {
74 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
75 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010076 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010077 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010078 OverrideDataType(input, dataType),
79 OverrideDataType(output, dataType),
80 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010081 reason);
telsoa014fcda012018-03-09 14:13:49 +000082 break;
83 }
84 case LayerType::Addition:
85 {
86 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
87 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
88 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010089 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010090 OverrideDataType(input0, dataType),
91 OverrideDataType(input1, dataType),
92 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010093 reason);
telsoa014fcda012018-03-09 14:13:49 +000094 break;
95 }
96 case LayerType::BatchNormalization:
97 {
98 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
99 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100100 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
101 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
102 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
103 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
104 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100105 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100106 OverrideDataType(input, dataType),
107 OverrideDataType(output, dataType),
108 OverrideDataType(mean, dataType),
109 OverrideDataType(var, dataType),
110 OverrideDataType(beta, dataType),
111 OverrideDataType(gamma, dataType),
112 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100113 reason);
telsoa014fcda012018-03-09 14:13:49 +0000114 break;
115 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000116 case LayerType::BatchToSpaceNd:
117 {
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
121
122 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
123 OverrideDataType(output, dataType),
124 cLayer->GetParameters(),
125 reason);
126 break;
127 }
telsoa014fcda012018-03-09 14:13:49 +0000128 case LayerType::Constant:
129 {
130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100131 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100132 break;
133 }
134 case LayerType::ConvertFp16ToFp32:
135 {
136 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100138 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100139 break;
140 }
141 case LayerType::ConvertFp32ToFp16:
142 {
143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100145 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000146 break;
147 }
148 case LayerType::Convolution2d:
149 {
150 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100151
152 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
153 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100154 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100155 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
156
arovir01a6824102018-08-28 17:40:45 +0100157 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100158
arovir01a6824102018-08-28 17:40:45 +0100159 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100160 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100161 if (descriptor.m_BiasEnabled)
162 {
David Beck5eec11d2018-10-04 15:43:17 +0100163 biases =
164 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100165 }
166
David Beck33f0ae02018-10-18 15:13:56 +0100167 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100168 input,
169 output,
170 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100171 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100172 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100173 reason);
telsoa014fcda012018-03-09 14:13:49 +0000174 break;
175 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000176 case LayerType::Debug:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180
181 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
182 OverrideDataType(output, dataType),
183 reason);
184 break;
185 }
telsoa014fcda012018-03-09 14:13:49 +0000186 case LayerType::DepthwiseConvolution2d:
187 {
188 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
190 dataType);
191 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
192 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
193
telsoa01c577f2c2018-08-31 09:22:23 +0100194 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100195
196 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100197 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100198 if (descriptor.m_BiasEnabled)
199 {
David Beck5eec11d2018-10-04 15:43:17 +0100200 biases =
201 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100202 }
telsoa01c577f2c2018-08-31 09:22:23 +0100203
David Beck33f0ae02018-10-18 15:13:56 +0100204 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100205 input,
206 output,
207 descriptor,
208 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100209 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100210 reason);
telsoa014fcda012018-03-09 14:13:49 +0000211 break;
212 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000213 case LayerType::Dequantize:
214 {
215 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
216 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
217
218 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
219 OverrideDataType(output, DataType::Float32),
220 reason);
221 break;
222 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000223 case LayerType::DetectionPostProcess:
224 {
225 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
226 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
227 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
228 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
229 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
230 input1,
231 descriptor,
232 reason);
233 break;
234 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000235 case LayerType::Equal:
236 {
237 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
238 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
240 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
241 OverrideDataType(input1, dataType),
242 OverrideDataType(output, dataType),
243 reason);
244 break;
245 }
telsoa014fcda012018-03-09 14:13:49 +0000246 case LayerType::FakeQuantization:
247 {
248 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
249 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100250 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
251 cLayer->GetParameters(),
252 reason);
telsoa014fcda012018-03-09 14:13:49 +0000253 break;
254 }
255 case LayerType::Floor:
256 {
257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
260 OverrideDataType(output, dataType),
261 reason);
telsoa014fcda012018-03-09 14:13:49 +0000262 break;
263 }
264 case LayerType::FullyConnected:
265 {
266 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
267 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100268 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
269 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
270
271 TensorInfo biasInfo;
272 const TensorInfo * biasInfoPtr = nullptr;
273 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
274 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
275 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
276
277 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
278 if (descriptor.m_BiasEnabled)
279 {
280 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
281 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
282 biasInfoPtr = &biasInfo;
283 }
284 else
285 {
286 // If biases are not enabled pass a dummy tensorinfo for the validation
287 switch(input.GetDataType())
288 {
289 case DataType::Float16:
290 {
291 biasInfoPtr = &dummyFloat16Bias;
292 break;
293 }
294 case DataType::Float32:
295 {
296 biasInfoPtr = &dummyFloat32Bias;
297 break;
298 }
299 case DataType::QuantisedAsymm8:
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100300 case DataType::QuantisedSymm16:
telsoa01c577f2c2018-08-31 09:22:23 +0100301 {
302 biasInfoPtr = &dummyQA8Bias;
303 break;
304 }
305 default:
306 {
307 BOOST_ASSERT_MSG(false, "Unexpected bias type");
308 }
309 }
310 }
311
David Beck33f0ae02018-10-18 15:13:56 +0100312 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100313 OverrideDataType(input, dataType),
314 OverrideDataType(output, dataType),
315 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
316 *biasInfoPtr,
317 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100318 reason);
telsoa014fcda012018-03-09 14:13:49 +0000319 break;
320 }
narpra01b89b05f2019-01-16 09:53:09 +0000321 case LayerType::Gather:
322 {
323 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
324 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
325 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
326 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100327 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000328 OverrideDataType(output, dataType),
329 reason);
330 break;
331 }
telsoa014fcda012018-03-09 14:13:49 +0000332 case LayerType::Input:
333 {
334 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100335 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000336 break;
337 }
338 case LayerType::L2Normalization:
339 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100340 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
341 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
342
telsoa014fcda012018-03-09 14:13:49 +0000343 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100344 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100345
David Beck33f0ae02018-10-18 15:13:56 +0100346 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100347 OverrideDataType(input, dataType),
348 OverrideDataType(output, dataType),
349 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100350 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100351 break;
352 }
353 case LayerType::Lstm:
354 {
355 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
356 const LstmDescriptor& descriptor = cLayer->GetParameters();
357
358 // All inputs.
359 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
360 dataType);
361 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
362 dataType);
363 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
364 dataType);
365 // All outputs
366 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
367 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
368 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
369 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
370
371 // Basic parameters
372 const TensorInfo& inputToForgetWeights
373 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
374 const TensorInfo& inputToCellWeights
375 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
376 const TensorInfo& inputToOutputWeights
377 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
378 const TensorInfo& recurrentToForgetWeights
379 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
380 const TensorInfo& recurrentToCellWeights
381 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
382 const TensorInfo& recurrentToOutputWeights
383 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
384 const TensorInfo& forgetGateBias
385 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
386 const TensorInfo& cellBias
387 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
388 const TensorInfo& outputGateBias
389 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
390
Jan Eilersd01a83c2019-07-03 18:20:40 +0100391 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100392
Jan Eilersd01a83c2019-07-03 18:20:40 +0100393 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
394 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
395 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
396 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
397 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
398 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
399 paramsInfo.m_ForgetGateBias = &forgetGateBias;
400 paramsInfo.m_CellBias = &cellBias;
401 paramsInfo.m_OutputGateBias = &outputGateBias;
402
403
404 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100405 TensorInfo optInputToInputWeights;
406 TensorInfo optRecurrentToInputWeights;
407 TensorInfo optCellToInputWeights;
408 TensorInfo optInputGateBias;
409 TensorInfo optProjectionWeights;
410 TensorInfo optProjectionBias;
411 TensorInfo optCellToForgetWeights;
412 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100413 TensorInfo optInputLayerNormWeights;
414 TensorInfo optForgetLayerNormWeights;
415 TensorInfo optCellLayerNormWeights;
416 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100417
418 if(!descriptor.m_CifgEnabled)
419 {
420 optInputToInputWeights =
421 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100422 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100423
424 optRecurrentToInputWeights =
425 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100426 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100427 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
428 {
429 optCellToInputWeights =
430 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100431 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100432 }
433 optInputGateBias =
434 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100435 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100436 }
437
438 if(descriptor.m_ProjectionEnabled)
439 {
440 optProjectionWeights =
441 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100442 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100443 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
444 {
445 optProjectionBias =
446 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100447 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100448 }
449 }
450
451 if(descriptor.m_PeepholeEnabled)
452 {
453 optCellToForgetWeights =
454 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100455 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100456 optCellToOutputWeights =
457 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100458 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100459 }
460
Jan Eilers38e05bd2019-06-26 13:10:09 +0100461 if(descriptor.m_LayerNormEnabled)
462 {
463 optInputLayerNormWeights = OverrideDataType(
464 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100465 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100466
467 optForgetLayerNormWeights = OverrideDataType(
468 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100469 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100470
471 optCellLayerNormWeights = OverrideDataType(
472 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100473 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100474
475 optOutputLayerNormWeights = OverrideDataType(
476 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100477 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100478 }
479
David Beck33f0ae02018-10-18 15:13:56 +0100480 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100481 input,
482 outputStateIn,
483 cellStateIn,
484 scratchBuffer,
485 outputStateOut,
486 cellStateOut,
487 output,
488 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100489 paramsInfo,
490 reason);
telsoa014fcda012018-03-09 14:13:49 +0000491 break;
492 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000493 case LayerType::Maximum:
494 {
495 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
496 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
497 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
498
499 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
500 OverrideDataType(input1, dataType),
501 OverrideDataType(output, dataType),
502 reason);
503 break;
504 }
narpra01b89b05f2019-01-16 09:53:09 +0000505 case LayerType::MemCopy:
506 {
507 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
508 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000509
narpra01b89b05f2019-01-16 09:53:09 +0000510 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
511 OverrideDataType(output, dataType),
512 reason);
513 break;
514 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100515 case LayerType::Merge:
516 {
517 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
518 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
519 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
520
521 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
522 OverrideDataType(input1, dataType),
523 OverrideDataType(output, dataType),
524 reason);
525 break;
526 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100527 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000528 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100529 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000530
telsoa01c577f2c2018-08-31 09:22:23 +0100531 // Get vector of all inputs.
532 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000533 {
telsoa01c577f2c2018-08-31 09:22:23 +0100534 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000535 };
telsoa01c577f2c2018-08-31 09:22:23 +0100536 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
537 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
538 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000539
telsoa01c577f2c2018-08-31 09:22:23 +0100540 auto getTensorInfoPtr = [](const TensorInfo& info)
541 {
542 return &info;
543 };
544 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
545 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
546 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000547
Nikhil Raj8599a412018-11-19 14:51:07 +0000548 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
549
Jim Flynne242f2d2019-05-22 14:24:13 +0100550 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
551
552
telsoa014fcda012018-03-09 14:13:49 +0000553 break;
554 }
555 case LayerType::Multiplication:
556 {
557 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
558 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100559 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100560 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100561 OverrideDataType(input0, dataType),
562 OverrideDataType(input1, dataType),
563 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100564 reason);
telsoa014fcda012018-03-09 14:13:49 +0000565 break;
566 }
567 case LayerType::Normalization:
568 {
569 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
570 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
571 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100572 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
573 OverrideDataType(output, dataType),
574 cLayer->GetParameters(),
575 reason);
telsoa014fcda012018-03-09 14:13:49 +0000576 break;
577 }
578 case LayerType::Output:
579 {
580 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100581 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000582 break;
583 }
584 case LayerType::Permute:
585 {
586 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
587 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
588 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100589 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
590 OverrideDataType(output, dataType),
591 cLayer->GetParameters(),
592 reason);
telsoa014fcda012018-03-09 14:13:49 +0000593 break;
594 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100595 case LayerType::Pad:
596 {
597 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
598 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
599 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100600 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100601 OverrideDataType(input, dataType),
602 OverrideDataType(output, dataType),
603 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100604 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100605 break;
606 }
telsoa014fcda012018-03-09 14:13:49 +0000607 case LayerType::Pooling2d:
608 {
609 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
610 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
611 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100612 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
613 OverrideDataType(output, dataType),
614 cLayer->GetParameters(),
615 reason);
telsoa014fcda012018-03-09 14:13:49 +0000616 break;
617 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000618 case LayerType::PreCompiled:
619 {
620 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
621 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
622 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
623 cLayer->GetParameters(),
624 reason);
625 break;
626 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000627 case LayerType::Quantize:
628 {
629 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
630 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
631 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
632 break;
633 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100634 case LayerType::Division:
635 {
636 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
637 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
638 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100639 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100640 OverrideDataType(input0, dataType),
641 OverrideDataType(input1, dataType),
642 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100643 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100644 break;
645 }
telsoa014fcda012018-03-09 14:13:49 +0000646 case LayerType::Reshape:
647 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000648 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000649 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000650 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
651 cLayer->GetParameters(),
652 reason);
telsoa014fcda012018-03-09 14:13:49 +0000653 break;
654 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100655 case LayerType::Resize:
656 {
657 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100658 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100659 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
660 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
661 OverrideDataType(output, dataType),
662 cLayer->GetParameters(),
663 reason);
664 break;
665 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000666 case LayerType::Rsqrt:
667 {
668 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
669 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
670 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
671 OverrideDataType(output, dataType),
672 reason);
673 break;
674 }
telsoa014fcda012018-03-09 14:13:49 +0000675 case LayerType::Softmax:
676 {
677 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
678 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100679 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100680 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
681 OverrideDataType(output, dataType),
682 cLayer->GetParameters(),
683 reason);
telsoa014fcda012018-03-09 14:13:49 +0000684 break;
685 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000686 case LayerType::SpaceToBatchNd:
687 {
688 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
689 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
690 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
691 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
692 OverrideDataType(output, dataType),
693 cLayer->GetParameters(),
694 reason);
695 break;
696 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100697 case LayerType::SpaceToDepth:
698 {
699 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
700
701 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
702 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
703
704 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
705 OverrideDataType(output, dataType),
706 cLayer->GetParameters(),
707 reason);
708 break;
709 }
telsoa014fcda012018-03-09 14:13:49 +0000710 case LayerType::Splitter:
711 {
712 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
713 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100714
715 // Get vector of all outputs.
716 auto getTensorInfo = [&dataType](const OutputSlot& slot)
717 {
718 return OverrideDataType(slot.GetTensorInfo(), dataType);
719 };
720 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
721 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
722 std::vector<TensorInfo> outputs(beginI, endI);
723
724 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
725
David Beck33f0ae02018-10-18 15:13:56 +0100726 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100727 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100728 cLayer->GetParameters(),
729 reason);
telsoa014fcda012018-03-09 14:13:49 +0000730 break;
731 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000732 case LayerType::StridedSlice:
733 {
734 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
735 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
736 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
737 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
738 OverrideDataType(output, dataType),
739 cLayer->GetParameters(),
740 reason);
741 break;
742 }
David Beckc2044fe2018-09-05 15:00:38 +0100743 case LayerType::Subtraction:
744 {
745 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
746 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
747 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100748 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100749 OverrideDataType(input0, dataType),
750 OverrideDataType(input1, dataType),
751 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100752 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100753 break;
754 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100755 case LayerType::Switch:
756 {
757 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
758 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
759 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
760 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
761 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
762 OverrideDataType(input1, dataType),
763 OverrideDataType(output0, dataType),
764 OverrideDataType(output1, dataType),
765 reason);
766 break;
767 }
narpra0132b90462018-09-13 11:07:48 +0100768 case LayerType::Mean:
769 {
770 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
771 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
772 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100773 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100774 OverrideDataType(input, dataType),
775 OverrideDataType(output, dataType),
776 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100777 reason);
narpra0132b90462018-09-13 11:07:48 +0100778 break;
779 }
kevmay0190539692018-11-29 08:40:19 +0000780 case LayerType::Minimum:
781 {
782 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
783 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
784 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
785 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
786 OverrideDataType(input1, dataType),
787 OverrideDataType(output, dataType),
788 reason);
789 break;
790 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000791 case LayerType::Greater:
792 {
793 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
794 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
795 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000796 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
797 OverrideDataType(input1, dataType),
798 OverrideDataType(output, DataType::Boolean),
799 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000800 break;
801 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +0100802 case LayerType::Prelu:
803 {
804 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
805 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
806 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
807 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
808 OverrideDataType(alpha, dataType),
809 OverrideDataType(output, dataType),
810 reason);
811 break;
812 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +0100813 case LayerType::TransposeConvolution2d:
814 {
815 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
816
817 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
818 dataType);
819 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
820
821 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
822
823 Optional<TensorInfo> biases;
824 if (descriptor.m_BiasEnabled)
825 {
826 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
827 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
828 GetBiasTypeFromWeightsType(dataType));
829 }
830
831 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
832 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
833
834 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
835 output,
836 descriptor,
837 weights,
838 biases,
839 reason);
840
841 break;
842 }
telsoa014fcda012018-03-09 14:13:49 +0000843 default:
844 {
845 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100846 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000847 result = false;
848 break;
849 }
850 }
telsoa014fcda012018-03-09 14:13:49 +0000851 return result;
852}
853
David Beckdcb751f2018-10-03 11:42:42 +0100854bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100855 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100856 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000857{
David Beckdcb751f2018-10-03 11:42:42 +0100858 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100859 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000860}
861
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000862// Default Implementations
863std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
864 const WorkloadInfo& info) const
865{
866 return std::unique_ptr<IWorkload>();
867}
868
869std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
870 const WorkloadInfo& info) const
871{
872 return std::unique_ptr<IWorkload>();
873}
874
875std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
876 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
877{
878 return std::unique_ptr<IWorkload>();
879}
880
881std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
882 const WorkloadInfo& Info) const
883{
884 return std::unique_ptr<IWorkload>();
885}
886
Jim Flynne242f2d2019-05-22 14:24:13 +0100887std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +0100888 const WorkloadInfo& info) const
889{
890 return std::unique_ptr<IWorkload>();
891}
892
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000893std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
894 const WorkloadInfo& info) const
895{
896 return std::unique_ptr<IWorkload>();
897}
898
899std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
900 const WorkloadInfo& info) const
901{
902 return std::unique_ptr<IWorkload>();
903}
904
905std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
906 const WorkloadInfo& info) const
907{
908 return std::unique_ptr<IWorkload>();
909}
910
911std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
912 const WorkloadInfo& info) const
913{
914 return std::unique_ptr<IWorkload>();
915}
916
917std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
918 const WorkloadInfo& info) const
919{
920 return std::unique_ptr<IWorkload>();
921}
922
923std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
924 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
925{
926 return std::unique_ptr<IWorkload>();
927}
928
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000929std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
930 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
931{
932 return std::unique_ptr<IWorkload>();
933}
934
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000935std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
936 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
937{
938 return std::unique_ptr<IWorkload>();
939}
940
941std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
942 const WorkloadInfo& info) const
943{
944 return std::unique_ptr<IWorkload>();
945}
946
947std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
948 const WorkloadInfo& Info) const
949{
950 return std::unique_ptr<IWorkload>();
951}
952
953std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
954 const WorkloadInfo& info) const
955{
956 return std::unique_ptr<IWorkload>();
957}
958
959std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
960 const WorkloadInfo& info) const
961{
962 return std::unique_ptr<IWorkload>();
963}
964
965std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
966 const WorkloadInfo& info) const
967{
968 return std::unique_ptr<IWorkload>();
969}
970
971std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
972 const WorkloadInfo& info) const
973{
974 return std::unique_ptr<IWorkload>();
975}
976
977std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
978 const WorkloadInfo& info) const
979{
980 return std::unique_ptr<IWorkload>();
981}
982
983std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
984 const WorkloadInfo& info) const
985{
986 return std::unique_ptr<IWorkload>();
987}
988
989std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
990 const WorkloadInfo& info) const
991{
992 return std::unique_ptr<IWorkload>();
993}
994
995std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
996 const WorkloadInfo& info) const
997{
998 return std::unique_ptr<IWorkload>();
999}
1000
1001std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1002 const WorkloadInfo& Info) const
1003{
1004 return std::unique_ptr<IWorkload>();
1005}
1006
1007std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1008 const WorkloadInfo& info) const
1009{
1010 return std::unique_ptr<IWorkload>();
1011}
1012
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001013std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1014 const WorkloadInfo& info) const
1015{
1016 return std::unique_ptr<IWorkload>();
1017}
1018
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001019std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1020 const WorkloadInfo& info) const
1021{
1022 return std::unique_ptr<IWorkload>();
1023}
1024
1025std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1026 const WorkloadInfo& info) const
1027{
1028 return std::unique_ptr<IWorkload>();
1029}
1030
1031std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1032 const WorkloadInfo& info) const
1033{
1034 return std::unique_ptr<IWorkload>();
1035}
1036
1037std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1038 const WorkloadInfo& info) const
1039{
1040 return std::unique_ptr<IWorkload>();
1041}
1042
1043std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1044 const WorkloadInfo& info) const
1045{
1046 return std::unique_ptr<IWorkload>();
1047}
1048
1049std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1050 const WorkloadInfo& Info) const
1051{
1052 return std::unique_ptr<IWorkload>();
1053}
1054
1055std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1056 const WorkloadInfo& info) const
1057{
1058 return std::unique_ptr<IWorkload>();
1059}
1060
1061std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1062 const WorkloadInfo& info) const
1063{
1064 return std::unique_ptr<IWorkload>();
1065}
1066
1067std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1068 const WorkloadInfo& info) const
1069{
1070 return std::unique_ptr<IWorkload>();
1071}
1072
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001073std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1074 const WorkloadInfo &info) const
1075{
1076 return std::unique_ptr<IWorkload>();
1077}
1078
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001079std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1080 const WorkloadInfo& Info) const
1081{
1082 return std::unique_ptr<IWorkload>();
1083}
1084
1085std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1086 const WorkloadInfo& info) const
1087{
1088 return std::unique_ptr<IWorkload>();
1089}
1090
1091std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1092 const WorkloadInfo& info) const
1093{
1094 return std::unique_ptr<IWorkload>();
1095}
1096
Teresa Charlina9075df2019-06-27 15:41:57 +01001097std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1098 const WorkloadInfo& info) const
1099{
1100 return std::unique_ptr<IWorkload>();
1101}
1102
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001103std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1104 const WorkloadInfo& info) const
1105{
1106 return std::unique_ptr<IWorkload>();
1107}
1108
1109std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1110 const WorkloadInfo& info) const
1111{
1112 return std::unique_ptr<IWorkload>();
1113}
1114
1115std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1116 const WorkloadInfo& info) const
1117{
1118 return std::unique_ptr<IWorkload>();
1119}
1120
1121std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1122 const WorkloadInfo& info) const
1123{
1124 return std::unique_ptr<IWorkload>();
1125}
1126
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001127std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1128 const WorkloadInfo& info) const
1129{
1130 return std::unique_ptr<IWorkload>();
1131}
1132
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001133std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1134 const WorkloadInfo& Info) const
1135{
1136 return std::unique_ptr<IWorkload>();
1137}
1138
1139std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1140 const WorkloadInfo& info) const
1141{
1142 return std::unique_ptr<IWorkload>();
1143}
1144
Sadik Armaganeff363d2019-04-05 15:25:46 +01001145std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1146 const WorkloadInfo& info) const
1147{
1148 return std::unique_ptr<IWorkload>();
1149}
1150
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001151std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1152 const TransposeConvolution2dQueueDescriptor& descriptor,
1153 const WorkloadInfo& info) const
1154{
1155 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001156}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001157
} // namespace armnn