blob: 1aca6bfb461d9d5e97bc75a9347d3909c1891c08 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
Derek Lambertia9cca6a2019-03-25 15:41:58 +00007#include "WorkloadFactory.hpp"
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Layer.hpp>
11#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010012
David Beckb4540be2018-09-24 13:18:27 +010013#include <armnn/Types.hpp>
14#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000015#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000019#include <backendsCommon/IBackendInternal.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/iterator/transform_iterator.hpp>
24
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000025#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000026#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000027
telsoa014fcda012018-03-09 14:13:49 +000028namespace armnn
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031namespace
32{
telsoa01c577f2c2018-08-31 09:22:23 +010033
David Beck29c75de2018-10-23 13:35:58 +010034const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
35{
36 if (!type)
37 {
38 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010039 }
40
David Beck29c75de2018-10-23 13:35:58 +010041 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
David Beck33f0ae02018-10-18 15:13:56 +010046bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010047 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010048 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010049 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000050{
David Beck33f0ae02018-10-18 15:13:56 +010051 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000052 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010053 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
54
David Beck111b5d92018-11-12 14:59:37 +000055 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
57 {
58 std::stringstream ss;
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
61
62 outReasonIfUnsupported = ss.str();
63 return false;
64 }
65
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010069
telsoa014fcda012018-03-09 14:13:49 +000070 switch(layer.GetType())
71 {
72 case LayerType::Activation:
73 {
74 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
75 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010076 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010077 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010078 OverrideDataType(input, dataType),
79 OverrideDataType(output, dataType),
80 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010081 reason);
telsoa014fcda012018-03-09 14:13:49 +000082 break;
83 }
84 case LayerType::Addition:
85 {
86 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
87 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
88 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010089 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010090 OverrideDataType(input0, dataType),
91 OverrideDataType(input1, dataType),
92 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010093 reason);
telsoa014fcda012018-03-09 14:13:49 +000094 break;
95 }
96 case LayerType::BatchNormalization:
97 {
98 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
99 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100100 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
101 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
102 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
103 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
104 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100105 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100106 OverrideDataType(input, dataType),
107 OverrideDataType(output, dataType),
108 OverrideDataType(mean, dataType),
109 OverrideDataType(var, dataType),
110 OverrideDataType(beta, dataType),
111 OverrideDataType(gamma, dataType),
112 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100113 reason);
telsoa014fcda012018-03-09 14:13:49 +0000114 break;
115 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000116 case LayerType::BatchToSpaceNd:
117 {
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
121
122 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
123 OverrideDataType(output, dataType),
124 cLayer->GetParameters(),
125 reason);
126 break;
127 }
telsoa014fcda012018-03-09 14:13:49 +0000128 case LayerType::Constant:
129 {
130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100131 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100132 break;
133 }
134 case LayerType::ConvertFp16ToFp32:
135 {
136 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100138 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100139 break;
140 }
141 case LayerType::ConvertFp32ToFp16:
142 {
143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100145 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000146 break;
147 }
148 case LayerType::Convolution2d:
149 {
150 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100151
152 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
153 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100154 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100155 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
156
arovir01a6824102018-08-28 17:40:45 +0100157 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100158
arovir01a6824102018-08-28 17:40:45 +0100159 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100160 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100161 if (descriptor.m_BiasEnabled)
162 {
David Beck5eec11d2018-10-04 15:43:17 +0100163 biases =
164 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100165 }
166
David Beck33f0ae02018-10-18 15:13:56 +0100167 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100168 input,
169 output,
170 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100171 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100172 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100173 reason);
telsoa014fcda012018-03-09 14:13:49 +0000174 break;
175 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000176 case LayerType::Debug:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180
181 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
182 OverrideDataType(output, dataType),
183 reason);
184 break;
185 }
telsoa014fcda012018-03-09 14:13:49 +0000186 case LayerType::DepthwiseConvolution2d:
187 {
188 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
190 dataType);
191 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
192 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
193
telsoa01c577f2c2018-08-31 09:22:23 +0100194 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100195
196 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100197 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100198 if (descriptor.m_BiasEnabled)
199 {
David Beck5eec11d2018-10-04 15:43:17 +0100200 biases =
201 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100202 }
telsoa01c577f2c2018-08-31 09:22:23 +0100203
David Beck33f0ae02018-10-18 15:13:56 +0100204 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100205 input,
206 output,
207 descriptor,
208 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100209 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100210 reason);
telsoa014fcda012018-03-09 14:13:49 +0000211 break;
212 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000213 case LayerType::Dequantize:
214 {
215 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
216 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
217
218 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
219 OverrideDataType(output, DataType::Float32),
220 reason);
221 break;
222 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000223 case LayerType::DetectionPostProcess:
224 {
225 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
226 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
227 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
228 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
229 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
230 input1,
231 descriptor,
232 reason);
233 break;
234 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000235 case LayerType::Equal:
236 {
237 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
238 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
240 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
241 OverrideDataType(input1, dataType),
242 OverrideDataType(output, dataType),
243 reason);
244 break;
245 }
telsoa014fcda012018-03-09 14:13:49 +0000246 case LayerType::FakeQuantization:
247 {
248 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
249 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100250 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
251 cLayer->GetParameters(),
252 reason);
telsoa014fcda012018-03-09 14:13:49 +0000253 break;
254 }
255 case LayerType::Floor:
256 {
257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
260 OverrideDataType(output, dataType),
261 reason);
telsoa014fcda012018-03-09 14:13:49 +0000262 break;
263 }
264 case LayerType::FullyConnected:
265 {
266 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
267 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100268 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
269 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
270
271 TensorInfo biasInfo;
272 const TensorInfo * biasInfoPtr = nullptr;
273 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
274 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
275 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
276
277 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
278 if (descriptor.m_BiasEnabled)
279 {
280 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
281 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
282 biasInfoPtr = &biasInfo;
283 }
284 else
285 {
286 // If biases are not enabled pass a dummy tensorinfo for the validation
287 switch(input.GetDataType())
288 {
289 case DataType::Float16:
290 {
291 biasInfoPtr = &dummyFloat16Bias;
292 break;
293 }
294 case DataType::Float32:
295 {
296 biasInfoPtr = &dummyFloat32Bias;
297 break;
298 }
299 case DataType::QuantisedAsymm8:
300 {
301 biasInfoPtr = &dummyQA8Bias;
302 break;
303 }
304 default:
305 {
306 BOOST_ASSERT_MSG(false, "Unexpected bias type");
307 }
308 }
309 }
310
David Beck33f0ae02018-10-18 15:13:56 +0100311 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100312 OverrideDataType(input, dataType),
313 OverrideDataType(output, dataType),
314 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
315 *biasInfoPtr,
316 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100317 reason);
telsoa014fcda012018-03-09 14:13:49 +0000318 break;
319 }
narpra01b89b05f2019-01-16 09:53:09 +0000320 case LayerType::Gather:
321 {
322 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
323 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
324 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
325 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100326 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000327 OverrideDataType(output, dataType),
328 reason);
329 break;
330 }
telsoa014fcda012018-03-09 14:13:49 +0000331 case LayerType::Input:
332 {
333 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100334 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000335 break;
336 }
337 case LayerType::L2Normalization:
338 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100339 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
340 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
341
telsoa014fcda012018-03-09 14:13:49 +0000342 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100343 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100344
David Beck33f0ae02018-10-18 15:13:56 +0100345 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100346 OverrideDataType(input, dataType),
347 OverrideDataType(output, dataType),
348 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100349 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100350 break;
351 }
352 case LayerType::Lstm:
353 {
354 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
355 const LstmDescriptor& descriptor = cLayer->GetParameters();
356
357 // All inputs.
358 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
359 dataType);
360 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
361 dataType);
362 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
363 dataType);
364 // All outputs
365 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
366 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
367 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
368 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
369
370 // Basic parameters
371 const TensorInfo& inputToForgetWeights
372 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
373 const TensorInfo& inputToCellWeights
374 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
375 const TensorInfo& inputToOutputWeights
376 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
377 const TensorInfo& recurrentToForgetWeights
378 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
379 const TensorInfo& recurrentToCellWeights
380 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
381 const TensorInfo& recurrentToOutputWeights
382 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
383 const TensorInfo& forgetGateBias
384 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
385 const TensorInfo& cellBias
386 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
387 const TensorInfo& outputGateBias
388 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
389
390 // Optional parameters
391 const TensorInfo* inputToInputWeights = nullptr;
392 const TensorInfo* recurrentToInputWeights = nullptr;
393 const TensorInfo* cellToInputWeights = nullptr;
394 const TensorInfo* inputGateBias = nullptr;
395 const TensorInfo* projectionWeights = nullptr;
396 const TensorInfo* projectionBias = nullptr;
397 const TensorInfo* cellToForgetWeights = nullptr;
398 const TensorInfo* cellToOutputWeights = nullptr;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100399 const TensorInfo* inputLayerNormWeights = nullptr;
400 const TensorInfo* forgetLayerNormWeights = nullptr;
401 const TensorInfo* cellLayerNormWeights = nullptr;
402 const TensorInfo* outputLayerNormWeights = nullptr;
telsoa01c577f2c2018-08-31 09:22:23 +0100403
404 TensorInfo optInputToInputWeights;
405 TensorInfo optRecurrentToInputWeights;
406 TensorInfo optCellToInputWeights;
407 TensorInfo optInputGateBias;
408 TensorInfo optProjectionWeights;
409 TensorInfo optProjectionBias;
410 TensorInfo optCellToForgetWeights;
411 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100412 TensorInfo optInputLayerNormWeights;
413 TensorInfo optForgetLayerNormWeights;
414 TensorInfo optCellLayerNormWeights;
415 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100416
417 if(!descriptor.m_CifgEnabled)
418 {
419 optInputToInputWeights =
420 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
421 inputToInputWeights = &optInputToInputWeights;
422
423 optRecurrentToInputWeights =
424 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
425 recurrentToInputWeights = &optRecurrentToInputWeights;
426 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
427 {
428 optCellToInputWeights =
429 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
430 cellToInputWeights = &optCellToInputWeights;
431 }
432 optInputGateBias =
433 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
434 inputGateBias = &optInputGateBias;
435 }
436
437 if(descriptor.m_ProjectionEnabled)
438 {
439 optProjectionWeights =
440 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
441 projectionWeights = &optProjectionWeights;
442 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
443 {
444 optProjectionBias =
445 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
446 projectionBias = &optProjectionBias;
447 }
448 }
449
450 if(descriptor.m_PeepholeEnabled)
451 {
452 optCellToForgetWeights =
453 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
454 cellToForgetWeights = &optCellToForgetWeights;
455 optCellToOutputWeights =
456 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
457 cellToOutputWeights = &optCellToOutputWeights;
458 }
459
Jan Eilers38e05bd2019-06-26 13:10:09 +0100460 if(descriptor.m_LayerNormEnabled)
461 {
462 optInputLayerNormWeights = OverrideDataType(
463 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
464 inputLayerNormWeights = &optInputLayerNormWeights;
465
466 optForgetLayerNormWeights = OverrideDataType(
467 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
468 forgetLayerNormWeights = &optForgetLayerNormWeights;
469
470 optCellLayerNormWeights = OverrideDataType(
471 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
472 cellLayerNormWeights = &optCellLayerNormWeights;
473
474 optOutputLayerNormWeights = OverrideDataType(
475 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
476 outputLayerNormWeights = &optOutputLayerNormWeights;
477 }
478
David Beck33f0ae02018-10-18 15:13:56 +0100479 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100480 input,
481 outputStateIn,
482 cellStateIn,
483 scratchBuffer,
484 outputStateOut,
485 cellStateOut,
486 output,
487 descriptor,
488 inputToForgetWeights,
489 inputToCellWeights,
490 inputToOutputWeights,
491 recurrentToForgetWeights,
492 recurrentToCellWeights,
493 recurrentToOutputWeights,
494 forgetGateBias,
495 cellBias,
496 outputGateBias,
497 inputToInputWeights,
498 recurrentToInputWeights,
499 cellToInputWeights,
500 inputGateBias,
501 projectionWeights,
502 projectionBias,
503 cellToForgetWeights,
504 cellToOutputWeights,
Jan Eilers38e05bd2019-06-26 13:10:09 +0100505 reason,
506 inputLayerNormWeights,
507 forgetLayerNormWeights,
508 cellLayerNormWeights,
509 outputLayerNormWeights);
telsoa014fcda012018-03-09 14:13:49 +0000510 break;
511 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000512 case LayerType::Maximum:
513 {
514 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
515 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
516 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
517
518 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
519 OverrideDataType(input1, dataType),
520 OverrideDataType(output, dataType),
521 reason);
522 break;
523 }
narpra01b89b05f2019-01-16 09:53:09 +0000524 case LayerType::MemCopy:
525 {
526 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
527 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000528
narpra01b89b05f2019-01-16 09:53:09 +0000529 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
530 OverrideDataType(output, dataType),
531 reason);
532 break;
533 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100534 case LayerType::Merge:
535 {
536 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
537 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
538 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
539
540 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
541 OverrideDataType(input1, dataType),
542 OverrideDataType(output, dataType),
543 reason);
544 break;
545 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100546 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000547 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100548 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000549
telsoa01c577f2c2018-08-31 09:22:23 +0100550 // Get vector of all inputs.
551 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000552 {
telsoa01c577f2c2018-08-31 09:22:23 +0100553 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000554 };
telsoa01c577f2c2018-08-31 09:22:23 +0100555 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
556 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
557 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000558
telsoa01c577f2c2018-08-31 09:22:23 +0100559 auto getTensorInfoPtr = [](const TensorInfo& info)
560 {
561 return &info;
562 };
563 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
564 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
565 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000566
Nikhil Raj8599a412018-11-19 14:51:07 +0000567 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
568
Jim Flynne242f2d2019-05-22 14:24:13 +0100569 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
570
571
telsoa014fcda012018-03-09 14:13:49 +0000572 break;
573 }
574 case LayerType::Multiplication:
575 {
576 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
577 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100578 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100579 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100580 OverrideDataType(input0, dataType),
581 OverrideDataType(input1, dataType),
582 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100583 reason);
telsoa014fcda012018-03-09 14:13:49 +0000584 break;
585 }
586 case LayerType::Normalization:
587 {
588 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
589 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
590 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100591 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
592 OverrideDataType(output, dataType),
593 cLayer->GetParameters(),
594 reason);
telsoa014fcda012018-03-09 14:13:49 +0000595 break;
596 }
597 case LayerType::Output:
598 {
599 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100600 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000601 break;
602 }
603 case LayerType::Permute:
604 {
605 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
606 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
607 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100608 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
609 OverrideDataType(output, dataType),
610 cLayer->GetParameters(),
611 reason);
telsoa014fcda012018-03-09 14:13:49 +0000612 break;
613 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100614 case LayerType::Pad:
615 {
616 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
617 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
618 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100619 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100620 OverrideDataType(input, dataType),
621 OverrideDataType(output, dataType),
622 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100623 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100624 break;
625 }
telsoa014fcda012018-03-09 14:13:49 +0000626 case LayerType::Pooling2d:
627 {
628 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
629 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
630 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100631 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
632 OverrideDataType(output, dataType),
633 cLayer->GetParameters(),
634 reason);
telsoa014fcda012018-03-09 14:13:49 +0000635 break;
636 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000637 case LayerType::PreCompiled:
638 {
639 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
640 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
641 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
642 cLayer->GetParameters(),
643 reason);
644 break;
645 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000646 case LayerType::Quantize:
647 {
648 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
649 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
650 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
651 break;
652 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100653 case LayerType::Division:
654 {
655 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
656 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
657 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100658 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100659 OverrideDataType(input0, dataType),
660 OverrideDataType(input1, dataType),
661 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100662 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100663 break;
664 }
telsoa014fcda012018-03-09 14:13:49 +0000665 case LayerType::Reshape:
666 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000667 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000668 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000669 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
670 cLayer->GetParameters(),
671 reason);
telsoa014fcda012018-03-09 14:13:49 +0000672 break;
673 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100674 case LayerType::Resize:
675 {
676 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100677 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100678 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
679 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
680 OverrideDataType(output, dataType),
681 cLayer->GetParameters(),
682 reason);
683 break;
684 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000685 case LayerType::Rsqrt:
686 {
687 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
688 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
689 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
690 OverrideDataType(output, dataType),
691 reason);
692 break;
693 }
telsoa014fcda012018-03-09 14:13:49 +0000694 case LayerType::Softmax:
695 {
696 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
697 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100698 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100699 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
700 OverrideDataType(output, dataType),
701 cLayer->GetParameters(),
702 reason);
telsoa014fcda012018-03-09 14:13:49 +0000703 break;
704 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000705 case LayerType::SpaceToBatchNd:
706 {
707 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
708 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
709 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
710 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
711 OverrideDataType(output, dataType),
712 cLayer->GetParameters(),
713 reason);
714 break;
715 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100716 case LayerType::SpaceToDepth:
717 {
718 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
719
720 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
721 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
722
723 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
724 OverrideDataType(output, dataType),
725 cLayer->GetParameters(),
726 reason);
727 break;
728 }
telsoa014fcda012018-03-09 14:13:49 +0000729 case LayerType::Splitter:
730 {
731 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
732 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100733
734 // Get vector of all outputs.
735 auto getTensorInfo = [&dataType](const OutputSlot& slot)
736 {
737 return OverrideDataType(slot.GetTensorInfo(), dataType);
738 };
739 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
740 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
741 std::vector<TensorInfo> outputs(beginI, endI);
742
743 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
744
David Beck33f0ae02018-10-18 15:13:56 +0100745 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100746 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100747 cLayer->GetParameters(),
748 reason);
telsoa014fcda012018-03-09 14:13:49 +0000749 break;
750 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000751 case LayerType::StridedSlice:
752 {
753 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
754 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
755 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
756 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
757 OverrideDataType(output, dataType),
758 cLayer->GetParameters(),
759 reason);
760 break;
761 }
David Beckc2044fe2018-09-05 15:00:38 +0100762 case LayerType::Subtraction:
763 {
764 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
765 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
766 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100767 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100768 OverrideDataType(input0, dataType),
769 OverrideDataType(input1, dataType),
770 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100771 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100772 break;
773 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100774 case LayerType::Switch:
775 {
776 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
777 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
778 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
779 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
780 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
781 OverrideDataType(input1, dataType),
782 OverrideDataType(output0, dataType),
783 OverrideDataType(output1, dataType),
784 reason);
785 break;
786 }
narpra0132b90462018-09-13 11:07:48 +0100787 case LayerType::Mean:
788 {
789 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
790 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
791 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100792 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100793 OverrideDataType(input, dataType),
794 OverrideDataType(output, dataType),
795 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100796 reason);
narpra0132b90462018-09-13 11:07:48 +0100797 break;
798 }
kevmay0190539692018-11-29 08:40:19 +0000799 case LayerType::Minimum:
800 {
801 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
802 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
803 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
804 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
805 OverrideDataType(input1, dataType),
806 OverrideDataType(output, dataType),
807 reason);
808 break;
809 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000810 case LayerType::Greater:
811 {
812 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
813 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
814 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000815 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
816 OverrideDataType(input1, dataType),
817 OverrideDataType(output, DataType::Boolean),
818 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000819 break;
820 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +0100821 case LayerType::Prelu:
822 {
823 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
824 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
825 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
826 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
827 OverrideDataType(alpha, dataType),
828 OverrideDataType(output, dataType),
829 reason);
830 break;
831 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +0100832 case LayerType::TransposeConvolution2d:
833 {
834 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
835
836 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
837 dataType);
838 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
839
840 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
841
842 Optional<TensorInfo> biases;
843 if (descriptor.m_BiasEnabled)
844 {
845 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
846 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
847 GetBiasTypeFromWeightsType(dataType));
848 }
849
850 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
851 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
852
853 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
854 output,
855 descriptor,
856 weights,
857 biases,
858 reason);
859
860 break;
861 }
telsoa014fcda012018-03-09 14:13:49 +0000862 default:
863 {
864 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100865 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000866 result = false;
867 break;
868 }
869 }
telsoa014fcda012018-03-09 14:13:49 +0000870 return result;
871}
872
David Beckdcb751f2018-10-03 11:42:42 +0100873bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100874 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100875 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000876{
David Beckdcb751f2018-10-03 11:42:42 +0100877 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100878 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000879}
880
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000881// Default Implementations
882std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
883 const WorkloadInfo& info) const
884{
885 return std::unique_ptr<IWorkload>();
886}
887
888std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
889 const WorkloadInfo& info) const
890{
891 return std::unique_ptr<IWorkload>();
892}
893
894std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
895 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
896{
897 return std::unique_ptr<IWorkload>();
898}
899
900std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
901 const WorkloadInfo& Info) const
902{
903 return std::unique_ptr<IWorkload>();
904}
905
Jim Flynne242f2d2019-05-22 14:24:13 +0100906std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
Jim Flynn4ed6c832019-05-20 11:02:46 +0100907 const WorkloadInfo& info) const
908{
909 return std::unique_ptr<IWorkload>();
910}
911
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000912std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
913 const WorkloadInfo& info) const
914{
915 return std::unique_ptr<IWorkload>();
916}
917
918std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
919 const WorkloadInfo& info) const
920{
921 return std::unique_ptr<IWorkload>();
922}
923
924std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
925 const WorkloadInfo& info) const
926{
927 return std::unique_ptr<IWorkload>();
928}
929
930std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
931 const WorkloadInfo& info) const
932{
933 return std::unique_ptr<IWorkload>();
934}
935
936std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
937 const WorkloadInfo& info) const
938{
939 return std::unique_ptr<IWorkload>();
940}
941
942std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
943 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
944{
945 return std::unique_ptr<IWorkload>();
946}
947
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000948std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
949 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
950{
951 return std::unique_ptr<IWorkload>();
952}
953
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000954std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
955 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
956{
957 return std::unique_ptr<IWorkload>();
958}
959
960std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
961 const WorkloadInfo& info) const
962{
963 return std::unique_ptr<IWorkload>();
964}
965
966std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
967 const WorkloadInfo& Info) const
968{
969 return std::unique_ptr<IWorkload>();
970}
971
972std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
973 const WorkloadInfo& info) const
974{
975 return std::unique_ptr<IWorkload>();
976}
977
978std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
979 const WorkloadInfo& info) const
980{
981 return std::unique_ptr<IWorkload>();
982}
983
984std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
985 const WorkloadInfo& info) const
986{
987 return std::unique_ptr<IWorkload>();
988}
989
990std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
991 const WorkloadInfo& info) const
992{
993 return std::unique_ptr<IWorkload>();
994}
995
996std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
997 const WorkloadInfo& info) const
998{
999 return std::unique_ptr<IWorkload>();
1000}
1001
1002std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
1003 const WorkloadInfo& info) const
1004{
1005 return std::unique_ptr<IWorkload>();
1006}
1007
1008std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
1009 const WorkloadInfo& info) const
1010{
1011 return std::unique_ptr<IWorkload>();
1012}
1013
1014std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
1015 const WorkloadInfo& info) const
1016{
1017 return std::unique_ptr<IWorkload>();
1018}
1019
1020std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1021 const WorkloadInfo& Info) const
1022{
1023 return std::unique_ptr<IWorkload>();
1024}
1025
1026std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1027 const WorkloadInfo& info) const
1028{
1029 return std::unique_ptr<IWorkload>();
1030}
1031
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001032std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1033 const WorkloadInfo& info) const
1034{
1035 return std::unique_ptr<IWorkload>();
1036}
1037
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001038std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1039 const WorkloadInfo& info) const
1040{
1041 return std::unique_ptr<IWorkload>();
1042}
1043
1044std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1045 const WorkloadInfo& info) const
1046{
1047 return std::unique_ptr<IWorkload>();
1048}
1049
1050std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1051 const WorkloadInfo& info) const
1052{
1053 return std::unique_ptr<IWorkload>();
1054}
1055
1056std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1057 const WorkloadInfo& info) const
1058{
1059 return std::unique_ptr<IWorkload>();
1060}
1061
1062std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1063 const WorkloadInfo& info) const
1064{
1065 return std::unique_ptr<IWorkload>();
1066}
1067
1068std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1069 const WorkloadInfo& Info) const
1070{
1071 return std::unique_ptr<IWorkload>();
1072}
1073
1074std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1075 const WorkloadInfo& info) const
1076{
1077 return std::unique_ptr<IWorkload>();
1078}
1079
1080std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1081 const WorkloadInfo& info) const
1082{
1083 return std::unique_ptr<IWorkload>();
1084}
1085
1086std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1087 const WorkloadInfo& info) const
1088{
1089 return std::unique_ptr<IWorkload>();
1090}
1091
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001092std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1093 const WorkloadInfo &info) const
1094{
1095 return std::unique_ptr<IWorkload>();
1096}
1097
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001098std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1099 const WorkloadInfo& Info) const
1100{
1101 return std::unique_ptr<IWorkload>();
1102}
1103
1104std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1105 const WorkloadInfo& info) const
1106{
1107 return std::unique_ptr<IWorkload>();
1108}
1109
1110std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1111 const WorkloadInfo& info) const
1112{
1113 return std::unique_ptr<IWorkload>();
1114}
1115
Teresa Charlina9075df2019-06-27 15:41:57 +01001116std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1117 const WorkloadInfo& info) const
1118{
1119 return std::unique_ptr<IWorkload>();
1120}
1121
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001122std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1123 const WorkloadInfo& info) const
1124{
1125 return std::unique_ptr<IWorkload>();
1126}
1127
1128std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1129 const WorkloadInfo& info) const
1130{
1131 return std::unique_ptr<IWorkload>();
1132}
1133
1134std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1135 const WorkloadInfo& info) const
1136{
1137 return std::unique_ptr<IWorkload>();
1138}
1139
1140std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1141 const WorkloadInfo& info) const
1142{
1143 return std::unique_ptr<IWorkload>();
1144}
1145
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001146std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1147 const WorkloadInfo& info) const
1148{
1149 return std::unique_ptr<IWorkload>();
1150}
1151
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001152std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1153 const WorkloadInfo& Info) const
1154{
1155 return std::unique_ptr<IWorkload>();
1156}
1157
1158std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1159 const WorkloadInfo& info) const
1160{
1161 return std::unique_ptr<IWorkload>();
1162}
1163
Sadik Armaganeff363d2019-04-05 15:25:46 +01001164std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1165 const WorkloadInfo& info) const
1166{
1167 return std::unique_ptr<IWorkload>();
1168}
1169
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001170std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1171 const TransposeConvolution2dQueueDescriptor& descriptor,
1172 const WorkloadInfo& info) const
1173{
1174 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001175}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001176
} // namespace armnn