blob: 6cca174141e9d19677b77b69333bf9629ccc913a [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
7
8#include <Layer.hpp>
9#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010010
David Beckb4540be2018-09-24 13:18:27 +010011#include <armnn/Types.hpp>
12#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000013#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
David Beck111b5d92018-11-12 14:59:37 +000015#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/IBackendInternal.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/iterator/transform_iterator.hpp>
21
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000023#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
27
telsoa01c577f2c2018-08-31 09:22:23 +010028namespace
29{
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
David Beck29c75de2018-10-23 13:35:58 +010038 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010039}
40
David Beck29c75de2018-10-23 13:35:58 +010041Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
42{
43 if (!weightsType)
44 {
45 return weightsType;
46 }
47
48 switch(weightsType.value())
49 {
50 case DataType::Float16:
51 case DataType::Float32:
52 return weightsType;
53 case DataType::QuantisedAsymm8:
54 return DataType::Signed32;
55 default:
56 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
57 }
58 return EmptyOptional();
59}
60
61} // anonymous namespace
62
David Beck33f0ae02018-10-18 15:13:56 +010063bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010064 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010065 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010066 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000067{
David Beck33f0ae02018-10-18 15:13:56 +010068 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000069 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010070 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
71
David Beck111b5d92018-11-12 14:59:37 +000072 auto const& backendRegistry = BackendRegistryInstance();
73 if (!backendRegistry.IsBackendRegistered(backendId))
74 {
75 std::stringstream ss;
76 ss << connectableLayer.GetName() << " is not supported on " << backendId
77 << " because this backend is not registered.";
78
79 outReasonIfUnsupported = ss.str();
80 return false;
81 }
82
83 auto backendFactory = backendRegistry.GetFactory(backendId);
84 auto backendObject = backendFactory();
85 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010086
telsoa014fcda012018-03-09 14:13:49 +000087 switch(layer.GetType())
88 {
89 case LayerType::Activation:
90 {
91 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
92 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010093 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010094 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010095 OverrideDataType(input, dataType),
96 OverrideDataType(output, dataType),
97 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010098 reason);
telsoa014fcda012018-03-09 14:13:49 +000099 break;
100 }
101 case LayerType::Addition:
102 {
103 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
104 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100106 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100107 OverrideDataType(input0, dataType),
108 OverrideDataType(input1, dataType),
109 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100110 reason);
telsoa014fcda012018-03-09 14:13:49 +0000111 break;
112 }
113 case LayerType::BatchNormalization:
114 {
115 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100122 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
137 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
138
139 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
143 break;
144 }
telsoa014fcda012018-03-09 14:13:49 +0000145 case LayerType::Constant:
146 {
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100148 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100149 break;
150 }
151 case LayerType::ConvertFp16ToFp32:
152 {
153 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
154 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100155 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100156 break;
157 }
158 case LayerType::ConvertFp32ToFp16:
159 {
160 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100162 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000163 break;
164 }
165 case LayerType::Convolution2d:
166 {
167 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100168
169 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
170 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100171 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100172 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
173
arovir01a6824102018-08-28 17:40:45 +0100174 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100175
arovir01a6824102018-08-28 17:40:45 +0100176 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100177 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100178 if (descriptor.m_BiasEnabled)
179 {
David Beck5eec11d2018-10-04 15:43:17 +0100180 biases =
181 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100182 }
183
David Beck33f0ae02018-10-18 15:13:56 +0100184 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100185 input,
186 output,
187 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100188 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100189 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100190 reason);
telsoa014fcda012018-03-09 14:13:49 +0000191 break;
192 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000193 case LayerType::Debug:
194 {
Nattapat Chaimanowongac5aa1f2018-12-05 15:17:18 +0000195 auto cLayer = boost::polymorphic_downcast<const DebugLayer*>(&layer);
Nattapat Chaimanowongac5aa1f2018-12-05 15:17:18 +0000196
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000197 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
198 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
199
200 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
201 OverrideDataType(output, dataType),
Matteo Martincigh49124022019-01-11 13:25:59 +0000202 cLayer->GetParameters(),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000203 reason);
204 break;
205 }
telsoa014fcda012018-03-09 14:13:49 +0000206 case LayerType::DepthwiseConvolution2d:
207 {
208 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100209 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
210 dataType);
211 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
212 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
213
telsoa01c577f2c2018-08-31 09:22:23 +0100214 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100215
216 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100217 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100218 if (descriptor.m_BiasEnabled)
219 {
David Beck5eec11d2018-10-04 15:43:17 +0100220 biases =
221 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100222 }
telsoa01c577f2c2018-08-31 09:22:23 +0100223
David Beck33f0ae02018-10-18 15:13:56 +0100224 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100225 input,
226 output,
227 descriptor,
228 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100229 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100230 reason);
telsoa014fcda012018-03-09 14:13:49 +0000231 break;
232 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000233 case LayerType::DetectionPostProcess:
234 {
235 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
236 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
237 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
238 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
239 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
240 input1,
241 descriptor,
242 reason);
243 break;
244 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000245 case LayerType::Equal:
246 {
247 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
248 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
249 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
250 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
251 OverrideDataType(input1, dataType),
252 OverrideDataType(output, dataType),
253 reason);
254 break;
255 }
telsoa014fcda012018-03-09 14:13:49 +0000256 case LayerType::FakeQuantization:
257 {
258 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
259 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100260 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
261 cLayer->GetParameters(),
262 reason);
telsoa014fcda012018-03-09 14:13:49 +0000263 break;
264 }
265 case LayerType::Floor:
266 {
267 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
268 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100269 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
270 OverrideDataType(output, dataType),
271 reason);
telsoa014fcda012018-03-09 14:13:49 +0000272 break;
273 }
274 case LayerType::FullyConnected:
275 {
276 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
277 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100278 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
279 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
280
281 TensorInfo biasInfo;
282 const TensorInfo * biasInfoPtr = nullptr;
283 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
284 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
285 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
286
287 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
288 if (descriptor.m_BiasEnabled)
289 {
290 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
291 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
292 biasInfoPtr = &biasInfo;
293 }
294 else
295 {
296 // If biases are not enabled pass a dummy tensorinfo for the validation
297 switch(input.GetDataType())
298 {
299 case DataType::Float16:
300 {
301 biasInfoPtr = &dummyFloat16Bias;
302 break;
303 }
304 case DataType::Float32:
305 {
306 biasInfoPtr = &dummyFloat32Bias;
307 break;
308 }
309 case DataType::QuantisedAsymm8:
310 {
311 biasInfoPtr = &dummyQA8Bias;
312 break;
313 }
314 default:
315 {
316 BOOST_ASSERT_MSG(false, "Unexpected bias type");
317 }
318 }
319 }
320
David Beck33f0ae02018-10-18 15:13:56 +0100321 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100322 OverrideDataType(input, dataType),
323 OverrideDataType(output, dataType),
324 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
325 *biasInfoPtr,
326 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100327 reason);
telsoa014fcda012018-03-09 14:13:49 +0000328 break;
329 }
narpra01b89b05f2019-01-16 09:53:09 +0000330 case LayerType::Gather:
331 {
332 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
333 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
334 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
335 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
336 OverrideDataType(input1, dataType),
337 OverrideDataType(output, dataType),
338 reason);
339 break;
340 }
telsoa014fcda012018-03-09 14:13:49 +0000341 case LayerType::Input:
342 {
343 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100344 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000345 break;
346 }
347 case LayerType::L2Normalization:
348 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100349 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
350 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
351
telsoa014fcda012018-03-09 14:13:49 +0000352 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100353 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100354
David Beck33f0ae02018-10-18 15:13:56 +0100355 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100356 OverrideDataType(input, dataType),
357 OverrideDataType(output, dataType),
358 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100359 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100360 break;
361 }
362 case LayerType::Lstm:
363 {
364 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
365 const LstmDescriptor& descriptor = cLayer->GetParameters();
366
367 // All inputs.
368 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
369 dataType);
370 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
371 dataType);
372 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
373 dataType);
374 // All outputs
375 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
376 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
377 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
378 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
379
380 // Basic parameters
381 const TensorInfo& inputToForgetWeights
382 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
383 const TensorInfo& inputToCellWeights
384 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
385 const TensorInfo& inputToOutputWeights
386 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
387 const TensorInfo& recurrentToForgetWeights
388 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
389 const TensorInfo& recurrentToCellWeights
390 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
391 const TensorInfo& recurrentToOutputWeights
392 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
393 const TensorInfo& forgetGateBias
394 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
395 const TensorInfo& cellBias
396 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
397 const TensorInfo& outputGateBias
398 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
399
400 // Optional parameters
401 const TensorInfo* inputToInputWeights = nullptr;
402 const TensorInfo* recurrentToInputWeights = nullptr;
403 const TensorInfo* cellToInputWeights = nullptr;
404 const TensorInfo* inputGateBias = nullptr;
405 const TensorInfo* projectionWeights = nullptr;
406 const TensorInfo* projectionBias = nullptr;
407 const TensorInfo* cellToForgetWeights = nullptr;
408 const TensorInfo* cellToOutputWeights = nullptr;
409
410 TensorInfo optInputToInputWeights;
411 TensorInfo optRecurrentToInputWeights;
412 TensorInfo optCellToInputWeights;
413 TensorInfo optInputGateBias;
414 TensorInfo optProjectionWeights;
415 TensorInfo optProjectionBias;
416 TensorInfo optCellToForgetWeights;
417 TensorInfo optCellToOutputWeights;
418
419 if(!descriptor.m_CifgEnabled)
420 {
421 optInputToInputWeights =
422 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
423 inputToInputWeights = &optInputToInputWeights;
424
425 optRecurrentToInputWeights =
426 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
427 recurrentToInputWeights = &optRecurrentToInputWeights;
428 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
429 {
430 optCellToInputWeights =
431 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
432 cellToInputWeights = &optCellToInputWeights;
433 }
434 optInputGateBias =
435 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
436 inputGateBias = &optInputGateBias;
437 }
438
439 if(descriptor.m_ProjectionEnabled)
440 {
441 optProjectionWeights =
442 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
443 projectionWeights = &optProjectionWeights;
444 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
445 {
446 optProjectionBias =
447 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
448 projectionBias = &optProjectionBias;
449 }
450 }
451
452 if(descriptor.m_PeepholeEnabled)
453 {
454 optCellToForgetWeights =
455 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
456 cellToForgetWeights = &optCellToForgetWeights;
457 optCellToOutputWeights =
458 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
459 cellToOutputWeights = &optCellToOutputWeights;
460 }
461
David Beck33f0ae02018-10-18 15:13:56 +0100462 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100463 input,
464 outputStateIn,
465 cellStateIn,
466 scratchBuffer,
467 outputStateOut,
468 cellStateOut,
469 output,
470 descriptor,
471 inputToForgetWeights,
472 inputToCellWeights,
473 inputToOutputWeights,
474 recurrentToForgetWeights,
475 recurrentToCellWeights,
476 recurrentToOutputWeights,
477 forgetGateBias,
478 cellBias,
479 outputGateBias,
480 inputToInputWeights,
481 recurrentToInputWeights,
482 cellToInputWeights,
483 inputGateBias,
484 projectionWeights,
485 projectionBias,
486 cellToForgetWeights,
487 cellToOutputWeights,
David Beck33f0ae02018-10-18 15:13:56 +0100488 reason);
telsoa014fcda012018-03-09 14:13:49 +0000489 break;
490 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000491 case LayerType::Maximum:
492 {
493 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
494 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
495 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
496
497 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
498 OverrideDataType(input1, dataType),
499 OverrideDataType(output, dataType),
500 reason);
501 break;
502 }
narpra01b89b05f2019-01-16 09:53:09 +0000503 case LayerType::MemCopy:
504 {
505 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
506 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000507
narpra01b89b05f2019-01-16 09:53:09 +0000508 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
509 OverrideDataType(output, dataType),
510 reason);
511 break;
512 }
telsoa014fcda012018-03-09 14:13:49 +0000513 case LayerType::Merger:
514 {
515 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
516
telsoa01c577f2c2018-08-31 09:22:23 +0100517 // Get vector of all inputs.
518 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000519 {
telsoa01c577f2c2018-08-31 09:22:23 +0100520 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000521 };
telsoa01c577f2c2018-08-31 09:22:23 +0100522 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
523 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
524 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000525
telsoa01c577f2c2018-08-31 09:22:23 +0100526 auto getTensorInfoPtr = [](const TensorInfo& info)
527 {
528 return &info;
529 };
530 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
531 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
532 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000533
Nikhil Raj8599a412018-11-19 14:51:07 +0000534 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
535
536 result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
telsoa014fcda012018-03-09 14:13:49 +0000537 break;
538 }
539 case LayerType::Multiplication:
540 {
541 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
542 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100543 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100544 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100545 OverrideDataType(input0, dataType),
546 OverrideDataType(input1, dataType),
547 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100548 reason);
telsoa014fcda012018-03-09 14:13:49 +0000549 break;
550 }
551 case LayerType::Normalization:
552 {
553 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
554 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
555 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100556 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
557 OverrideDataType(output, dataType),
558 cLayer->GetParameters(),
559 reason);
telsoa014fcda012018-03-09 14:13:49 +0000560 break;
561 }
562 case LayerType::Output:
563 {
564 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100565 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000566 break;
567 }
568 case LayerType::Permute:
569 {
570 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
571 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
572 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100573 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
574 OverrideDataType(output, dataType),
575 cLayer->GetParameters(),
576 reason);
telsoa014fcda012018-03-09 14:13:49 +0000577 break;
578 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100579 case LayerType::Pad:
580 {
581 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
582 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
583 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100584 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100585 OverrideDataType(input, dataType),
586 OverrideDataType(output, dataType),
587 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100588 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100589 break;
590 }
telsoa014fcda012018-03-09 14:13:49 +0000591 case LayerType::Pooling2d:
592 {
593 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
594 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
595 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100596 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
597 OverrideDataType(output, dataType),
598 cLayer->GetParameters(),
599 reason);
telsoa014fcda012018-03-09 14:13:49 +0000600 break;
601 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000602 case LayerType::PreCompiled:
603 {
604 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
605 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
606 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
607 cLayer->GetParameters(),
608 reason);
609 break;
610 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100611 case LayerType::Division:
612 {
613 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
614 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
615 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100616 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100617 OverrideDataType(input0, dataType),
618 OverrideDataType(input1, dataType),
619 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100620 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100621 break;
622 }
telsoa014fcda012018-03-09 14:13:49 +0000623 case LayerType::Reshape:
624 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000625 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000626 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000627 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
628 cLayer->GetParameters(),
629 reason);
telsoa014fcda012018-03-09 14:13:49 +0000630 break;
631 }
632 case LayerType::ResizeBilinear:
633 {
634 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganc625f002018-12-17 11:32:16 +0000635 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
636 result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType),
637 OverrideDataType(output, dataType),
638 reason);
telsoa014fcda012018-03-09 14:13:49 +0000639 break;
640 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000641 case LayerType::Rsqrt:
642 {
643 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
644 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
645 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
646 OverrideDataType(output, dataType),
647 reason);
648 break;
649 }
telsoa014fcda012018-03-09 14:13:49 +0000650 case LayerType::Softmax:
651 {
652 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
653 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100654 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100655 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
656 OverrideDataType(output, dataType),
657 cLayer->GetParameters(),
658 reason);
telsoa014fcda012018-03-09 14:13:49 +0000659 break;
660 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000661 case LayerType::SpaceToBatchNd:
662 {
663 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
664 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
665 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
666 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
667 OverrideDataType(output, dataType),
668 cLayer->GetParameters(),
669 reason);
670 break;
671 }
telsoa014fcda012018-03-09 14:13:49 +0000672 case LayerType::Splitter:
673 {
674 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
675 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100676 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
677 cLayer->GetParameters(),
678 reason);
telsoa014fcda012018-03-09 14:13:49 +0000679 break;
680 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000681 case LayerType::StridedSlice:
682 {
683 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
684 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
685 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
686 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
687 OverrideDataType(output, dataType),
688 cLayer->GetParameters(),
689 reason);
690 break;
691 }
David Beckc2044fe2018-09-05 15:00:38 +0100692 case LayerType::Subtraction:
693 {
694 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
695 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
696 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100697 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100698 OverrideDataType(input0, dataType),
699 OverrideDataType(input1, dataType),
700 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100701 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100702 break;
703 }
narpra0132b90462018-09-13 11:07:48 +0100704 case LayerType::Mean:
705 {
706 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
707 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
708 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100709 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100710 OverrideDataType(input, dataType),
711 OverrideDataType(output, dataType),
712 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100713 reason);
narpra0132b90462018-09-13 11:07:48 +0100714 break;
715 }
kevmay0190539692018-11-29 08:40:19 +0000716 case LayerType::Minimum:
717 {
718 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
719 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
720 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
721 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
722 OverrideDataType(input1, dataType),
723 OverrideDataType(output, dataType),
724 reason);
725 break;
726 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000727 case LayerType::Greater:
728 {
729 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
730 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
731 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
kevmay01eed85922019-01-28 08:37:25 +0000732 result = layerSupportObject->IsGreaterSupported(input0, input1, output, reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000733 break;
734 }
telsoa014fcda012018-03-09 14:13:49 +0000735 default:
736 {
737 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100738 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000739 result = false;
740 break;
741 }
742 }
telsoa014fcda012018-03-09 14:13:49 +0000743 return result;
744}
745
David Beckdcb751f2018-10-03 11:42:42 +0100746bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100747 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100748 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000749{
David Beckdcb751f2018-10-03 11:42:42 +0100750 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100751 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000752}
753
surmeh013537c2c2018-05-18 16:31:43 +0100754}