//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CpuTensorHandle.hpp"

#include <Layer.hpp>
#include <LayersFwd.hpp>

#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <armnn/ILayerSupport.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/IBackendInternal.hpp>

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
#include <sstream>

namespace armnn
{

namespace
{

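// Returns a copy of 'info' with its data type replaced by 'type' when a type
// is given; when 'type' is empty, the original TensorInfo is returned
// unchanged. Shape, quantization scale and offset are preserved either way.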
const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }

    return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
}

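// Maps a weights data type to the bias data type that must accompany it:
// float weights take biases of the same float type, while QuantisedAsymm8
// weights take Signed32 biases (the widened accumulator type used for
// quantized arithmetic). An empty optional is passed through unchanged.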
Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
{
    if (!weightsType)
    {
        return weightsType;
    }

    switch(weightsType.value())
    {
        case DataType::Float16:
        case DataType::Float32:
            return weightsType;
        case DataType::QuantisedAsymm8:
            return DataType::Signed32;
        default:
            BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
    }
    return EmptyOptional();
}

} // anonymous namespace

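// Checks whether the given layer can execute on the backend identified by
// 'backendId'. The backend is looked up in the registry and the query is
// delegated to its ILayerSupport object; 'dataType', when set, overrides the
// data type of every tensor passed to that query. A minimal usage sketch
// (the surrounding variables are assumptions, not part of this file):
//
//     std::string reason;
//     if (!IWorkloadFactory::IsLayerSupported(BackendId("CpuRef"), layer,
//                                             DataType::Float32, reason))
//     {
//         // 'reason' now describes why the layer was rejected.
//     }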
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

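    // Dispatch on the concrete layer type: each case downcasts to the
    // specific layer class where its parameters are needed, gathers the
    // relevant input/output TensorInfos (with the optional data type
    // override applied), and forwards everything to the backend's
    // corresponding Is*Supported query.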
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(mean, dataType),
                                           OverrideDataType(var, dataType),
                                           OverrideDataType(beta, dataType),
                                           OverrideDataType(gamma, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                           input,
                                           output,
                                           descriptor,
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           biases,
                                           reason);
            break;
        }
        case LayerType::Debug:
        {
            auto cLayer = boost::polymorphic_downcast<const DebugLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          cLayer->GetParameters(),
                                                          reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                           input,
                                           output,
                                           descriptor,
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           biases,
                                           reason);
            break;
        }
        case LayerType::Equal:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo * biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled, pass a dummy TensorInfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           *biasInfoPtr,
                                           descriptor,
                                           reason);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           descriptor,
                                           reason);
            break;
        }
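        // The Lstm case is the most involved one: besides three inputs and
        // four outputs it collects the basic gate weights and biases, plus
        // the optional CIFG, projection and peephole parameter sets, passing
        // null pointers for whichever optional tensors the descriptor
        // disables.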
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                           input,
                                           outputStateIn,
                                           cellStateIn,
                                           scratchBuffer,
                                           outputStateOut,
                                           cellStateOut,
                                           output,
                                           descriptor,
                                           inputToForgetWeights,
                                           inputToCellWeights,
                                           inputToOutputWeights,
                                           recurrentToForgetWeights,
                                           recurrentToCellWeights,
                                           recurrentToOutputWeights,
                                           forgetGateBias,
                                           cellBias,
                                           outputGateBias,
                                           inputToInputWeights,
                                           recurrentToInputWeights,
                                           cellToInputWeights,
                                           inputGateBias,
                                           projectionWeights,
                                           projectionBias,
                                           cellToForgetWeights,
                                           cellToOutputWeights,
                                           reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
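        // Merger takes a variable number of inputs, so the case below first
        // materialises a vector of per-input TensorInfos and then a parallel
        // vector of pointers into it, which is the form IsMergerSupported
        // expects.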
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::PreCompiled:
        {
            auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::Rsqrt:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                           OverrideDataType(input0, dataType),
                                           OverrideDataType(input1, dataType),
                                           OverrideDataType(output, dataType),
                                           reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Minimum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Greater:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}

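// Convenience overload: resolves the backend from the BackendId already
// assigned to the layer and forwards to the overload above.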
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
    return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
}

} // namespace armnn