//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CpuTensorHandle.hpp"

#include <Layer.hpp>
#include <LayersFwd.hpp>

#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <armnn/ILayerSupport.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/IBackendInternal.hpp>

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
#include <sstream>

namespace armnn
{

namespace
{

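// Helper: returns a copy of 'info' with its data type replaced by 'type' when a
// type is provided; shape and quantization parameters are preserved.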
const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }

    return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
}

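// Helper: maps a weights data type to the bias data type the backends expect.
// Float weights use biases of the same type; quantised 8-bit weights use Signed32 biases.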
Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
{
    if (!weightsType)
    {
        return weightsType;
    }

    switch(weightsType.value())
    {
        case DataType::Float16:
        case DataType::Float32:
            return weightsType;
        case DataType::QuantisedAsymm8:
            return DataType::Signed32;
        default:
            BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
    }
    return EmptyOptional();
}

} // anonymous namespace

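// Queries the backend identified by 'backendId' for support of the given layer.
// 'dataType', when set, overrides the data type of every tensor involved (e.g. to
// check whether the network would still be supported at a different precision).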
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

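    // Dispatch on the layer type: each case gathers the layer's input/output
    // TensorInfos (with the optional data type override applied) and forwards
    // them to the backend's ILayerSupport query.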
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(mean, dataType),
                                           OverrideDataType(var, dataType),
                                           OverrideDataType(beta, dataType),
                                           OverrideDataType(gamma, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                           input,
                                           output,
                                           descriptor,
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           biases,
                                           reason);
            break;
        }
        case LayerType::MemCopy:
        {
            // MemCopy is supported for the CpuRef, CpuAcc and GpuAcc backends
            // (Undefined is also treated as CpuRef to avoid breaking lots of unit tests).
            result = backendId == Compute::CpuRef || backendId == Compute::Undefined
                     || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
            reason.value() = "Unsupported backend type";
            break;
        }
        case LayerType::Debug:
        {
            auto cLayer = boost::polymorphic_downcast<const DebugLayer*>(&layer);
            const DebugDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          descriptor,
                                                          reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                           input,
                                           output,
                                           descriptor,
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           biases,
                                           reason);
            break;
        }
        case LayerType::Equal:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo* biasInfoPtr = nullptr;
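            // The support query below requires a bias TensorInfo even when the layer
            // has no bias; static dummy infos stand in, chosen per input data type.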
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy TensorInfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           *biasInfoPtr,
                                           descriptor,
                                           reason);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           descriptor,
                                           reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

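            // CIFG (coupled input-forget gate) omits the input gate entirely, so its
            // weights and bias only exist when CIFG is disabled.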
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                           input,
                                           outputStateIn,
                                           cellStateIn,
                                           scratchBuffer,
                                           outputStateOut,
                                           cellStateOut,
                                           output,
                                           descriptor,
                                           inputToForgetWeights,
                                           inputToCellWeights,
                                           inputToOutputWeights,
                                           recurrentToForgetWeights,
                                           recurrentToCellWeights,
                                           recurrentToOutputWeights,
                                           forgetGateBias,
                                           cellBias,
                                           outputGateBias,
                                           inputToInputWeights,
                                           recurrentToInputWeights,
                                           cellToInputWeights,
                                           inputGateBias,
                                           projectionWeights,
                                           projectionBias,
                                           cellToForgetWeights,
                                           cellToOutputWeights,
                                           reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

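            // IsMergerSupported takes TensorInfo pointers, so build a parallel
            // vector of pointers into 'inputs'.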
            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::Reshape:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Minimum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Greater:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}

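// Convenience overload: checks support on the backend already assigned to the layer.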
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
    return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
}

} // namespace armnn
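// Illustrative usage (not part of the original file): a minimal sketch of how a
// caller might query layer support. It assumes 'layer' is an IConnectableLayer*
// obtained from an armnn network, that the CpuRef backend is registered, and that
// the declaration in WorkloadFactory.hpp exposes IsLayerSupported as a static
// member, as its lack of instance state here suggests.
//
//     std::string reason;
//     bool supported = armnn::IWorkloadFactory::IsLayerSupported(
//         armnn::Compute::CpuRef,      // BackendId is constructible from Compute
//         *layer,
//         armnn::DataType::Float32,    // optional override of the tensor data types
//         reason);
//     if (!supported)
//     {
//         std::cerr << reason << std::endl;
//     }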