blob: 0996a8aaee603fed7643a4dc2f557a8878a6edaa [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
7
8#include <Layer.hpp>
9#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010010
David Beckb4540be2018-09-24 13:18:27 +010011#include <armnn/Types.hpp>
12#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000013#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
David Beck111b5d92018-11-12 14:59:37 +000015#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/IBackendInternal.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/iterator/transform_iterator.hpp>
21
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000023#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
27
telsoa01c577f2c2018-08-31 09:22:23 +010028namespace
29{
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
David Beck29c75de2018-10-23 13:35:58 +010038 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010039}
40
David Beck29c75de2018-10-23 13:35:58 +010041Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
42{
43 if (!weightsType)
44 {
45 return weightsType;
46 }
47
48 switch(weightsType.value())
49 {
50 case DataType::Float16:
51 case DataType::Float32:
52 return weightsType;
53 case DataType::QuantisedAsymm8:
54 return DataType::Signed32;
55 default:
56 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
57 }
58 return EmptyOptional();
59}
60
61} // anonymous namespace
62
David Beck33f0ae02018-10-18 15:13:56 +010063bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010064 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010065 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010066 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000067{
David Beck33f0ae02018-10-18 15:13:56 +010068 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000069 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010070 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
71
David Beck111b5d92018-11-12 14:59:37 +000072 auto const& backendRegistry = BackendRegistryInstance();
73 if (!backendRegistry.IsBackendRegistered(backendId))
74 {
75 std::stringstream ss;
76 ss << connectableLayer.GetName() << " is not supported on " << backendId
77 << " because this backend is not registered.";
78
79 outReasonIfUnsupported = ss.str();
80 return false;
81 }
82
83 auto backendFactory = backendRegistry.GetFactory(backendId);
84 auto backendObject = backendFactory();
85 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010086
telsoa014fcda012018-03-09 14:13:49 +000087 switch(layer.GetType())
88 {
89 case LayerType::Activation:
90 {
91 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
92 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010093 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010094 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010095 OverrideDataType(input, dataType),
96 OverrideDataType(output, dataType),
97 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010098 reason);
telsoa014fcda012018-03-09 14:13:49 +000099 break;
100 }
101 case LayerType::Addition:
102 {
103 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
104 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100106 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100107 OverrideDataType(input0, dataType),
108 OverrideDataType(input1, dataType),
109 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100110 reason);
telsoa014fcda012018-03-09 14:13:49 +0000111 break;
112 }
113 case LayerType::BatchNormalization:
114 {
115 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100122 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
137 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
138
139 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
143 break;
144 }
telsoa014fcda012018-03-09 14:13:49 +0000145 case LayerType::Constant:
146 {
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100148 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100149 break;
150 }
151 case LayerType::ConvertFp16ToFp32:
152 {
153 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
154 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100155 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100156 break;
157 }
158 case LayerType::ConvertFp32ToFp16:
159 {
160 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100162 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000163 break;
164 }
165 case LayerType::Convolution2d:
166 {
167 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100168
169 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
170 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100171 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100172 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
173
arovir01a6824102018-08-28 17:40:45 +0100174 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100175
arovir01a6824102018-08-28 17:40:45 +0100176 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100177 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100178 if (descriptor.m_BiasEnabled)
179 {
David Beck5eec11d2018-10-04 15:43:17 +0100180 biases =
181 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100182 }
183
David Beck33f0ae02018-10-18 15:13:56 +0100184 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100185 input,
186 output,
187 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100188 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100189 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100190 reason);
telsoa014fcda012018-03-09 14:13:49 +0000191 break;
192 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000193 case LayerType::Debug:
194 {
195 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
196 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
197
198 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
199 OverrideDataType(output, dataType),
200 reason);
201 break;
202 }
telsoa014fcda012018-03-09 14:13:49 +0000203 case LayerType::DepthwiseConvolution2d:
204 {
205 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100206 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
207 dataType);
208 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
209 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
210
telsoa01c577f2c2018-08-31 09:22:23 +0100211 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100212
213 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100214 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100215 if (descriptor.m_BiasEnabled)
216 {
David Beck5eec11d2018-10-04 15:43:17 +0100217 biases =
218 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100219 }
telsoa01c577f2c2018-08-31 09:22:23 +0100220
David Beck33f0ae02018-10-18 15:13:56 +0100221 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100222 input,
223 output,
224 descriptor,
225 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100226 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100227 reason);
telsoa014fcda012018-03-09 14:13:49 +0000228 break;
229 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000230 case LayerType::DetectionPostProcess:
231 {
232 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
234 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
235 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
236 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
237 input1,
238 descriptor,
239 reason);
240 break;
241 }
FrancisMurtagh20995952018-12-17 12:11:36 +0000242 case LayerType::Equal:
243 {
244 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
245 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
246 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
247 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
248 OverrideDataType(input1, dataType),
249 OverrideDataType(output, dataType),
250 reason);
251 break;
252 }
telsoa014fcda012018-03-09 14:13:49 +0000253 case LayerType::FakeQuantization:
254 {
255 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
256 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100257 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
258 cLayer->GetParameters(),
259 reason);
telsoa014fcda012018-03-09 14:13:49 +0000260 break;
261 }
262 case LayerType::Floor:
263 {
264 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
265 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100266 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
267 OverrideDataType(output, dataType),
268 reason);
telsoa014fcda012018-03-09 14:13:49 +0000269 break;
270 }
271 case LayerType::FullyConnected:
272 {
273 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
274 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100275 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
276 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
277
278 TensorInfo biasInfo;
279 const TensorInfo * biasInfoPtr = nullptr;
280 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
281 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
282 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
283
284 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
285 if (descriptor.m_BiasEnabled)
286 {
287 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
288 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
289 biasInfoPtr = &biasInfo;
290 }
291 else
292 {
293 // If biases are not enabled pass a dummy tensorinfo for the validation
294 switch(input.GetDataType())
295 {
296 case DataType::Float16:
297 {
298 biasInfoPtr = &dummyFloat16Bias;
299 break;
300 }
301 case DataType::Float32:
302 {
303 biasInfoPtr = &dummyFloat32Bias;
304 break;
305 }
306 case DataType::QuantisedAsymm8:
307 {
308 biasInfoPtr = &dummyQA8Bias;
309 break;
310 }
311 default:
312 {
313 BOOST_ASSERT_MSG(false, "Unexpected bias type");
314 }
315 }
316 }
317
David Beck33f0ae02018-10-18 15:13:56 +0100318 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100319 OverrideDataType(input, dataType),
320 OverrideDataType(output, dataType),
321 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
322 *biasInfoPtr,
323 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100324 reason);
telsoa014fcda012018-03-09 14:13:49 +0000325 break;
326 }
narpra01b89b05f2019-01-16 09:53:09 +0000327 case LayerType::Gather:
328 {
329 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
330 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
331 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
332 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
333 OverrideDataType(input1, dataType),
334 OverrideDataType(output, dataType),
335 reason);
336 break;
337 }
telsoa014fcda012018-03-09 14:13:49 +0000338 case LayerType::Input:
339 {
340 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100341 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000342 break;
343 }
344 case LayerType::L2Normalization:
345 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100346 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
347 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
348
telsoa014fcda012018-03-09 14:13:49 +0000349 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100350 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100351
David Beck33f0ae02018-10-18 15:13:56 +0100352 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100353 OverrideDataType(input, dataType),
354 OverrideDataType(output, dataType),
355 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100356 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100357 break;
358 }
359 case LayerType::Lstm:
360 {
361 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
362 const LstmDescriptor& descriptor = cLayer->GetParameters();
363
364 // All inputs.
365 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
366 dataType);
367 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
368 dataType);
369 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
370 dataType);
371 // All outputs
372 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
373 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
374 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
375 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
376
377 // Basic parameters
378 const TensorInfo& inputToForgetWeights
379 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
380 const TensorInfo& inputToCellWeights
381 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
382 const TensorInfo& inputToOutputWeights
383 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
384 const TensorInfo& recurrentToForgetWeights
385 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
386 const TensorInfo& recurrentToCellWeights
387 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
388 const TensorInfo& recurrentToOutputWeights
389 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
390 const TensorInfo& forgetGateBias
391 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
392 const TensorInfo& cellBias
393 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
394 const TensorInfo& outputGateBias
395 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
396
397 // Optional parameters
398 const TensorInfo* inputToInputWeights = nullptr;
399 const TensorInfo* recurrentToInputWeights = nullptr;
400 const TensorInfo* cellToInputWeights = nullptr;
401 const TensorInfo* inputGateBias = nullptr;
402 const TensorInfo* projectionWeights = nullptr;
403 const TensorInfo* projectionBias = nullptr;
404 const TensorInfo* cellToForgetWeights = nullptr;
405 const TensorInfo* cellToOutputWeights = nullptr;
406
407 TensorInfo optInputToInputWeights;
408 TensorInfo optRecurrentToInputWeights;
409 TensorInfo optCellToInputWeights;
410 TensorInfo optInputGateBias;
411 TensorInfo optProjectionWeights;
412 TensorInfo optProjectionBias;
413 TensorInfo optCellToForgetWeights;
414 TensorInfo optCellToOutputWeights;
415
416 if(!descriptor.m_CifgEnabled)
417 {
418 optInputToInputWeights =
419 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
420 inputToInputWeights = &optInputToInputWeights;
421
422 optRecurrentToInputWeights =
423 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
424 recurrentToInputWeights = &optRecurrentToInputWeights;
425 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
426 {
427 optCellToInputWeights =
428 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
429 cellToInputWeights = &optCellToInputWeights;
430 }
431 optInputGateBias =
432 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
433 inputGateBias = &optInputGateBias;
434 }
435
436 if(descriptor.m_ProjectionEnabled)
437 {
438 optProjectionWeights =
439 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
440 projectionWeights = &optProjectionWeights;
441 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
442 {
443 optProjectionBias =
444 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
445 projectionBias = &optProjectionBias;
446 }
447 }
448
449 if(descriptor.m_PeepholeEnabled)
450 {
451 optCellToForgetWeights =
452 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
453 cellToForgetWeights = &optCellToForgetWeights;
454 optCellToOutputWeights =
455 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
456 cellToOutputWeights = &optCellToOutputWeights;
457 }
458
David Beck33f0ae02018-10-18 15:13:56 +0100459 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100460 input,
461 outputStateIn,
462 cellStateIn,
463 scratchBuffer,
464 outputStateOut,
465 cellStateOut,
466 output,
467 descriptor,
468 inputToForgetWeights,
469 inputToCellWeights,
470 inputToOutputWeights,
471 recurrentToForgetWeights,
472 recurrentToCellWeights,
473 recurrentToOutputWeights,
474 forgetGateBias,
475 cellBias,
476 outputGateBias,
477 inputToInputWeights,
478 recurrentToInputWeights,
479 cellToInputWeights,
480 inputGateBias,
481 projectionWeights,
482 projectionBias,
483 cellToForgetWeights,
484 cellToOutputWeights,
David Beck33f0ae02018-10-18 15:13:56 +0100485 reason);
telsoa014fcda012018-03-09 14:13:49 +0000486 break;
487 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000488 case LayerType::Maximum:
489 {
490 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
491 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
492 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
493
494 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
495 OverrideDataType(input1, dataType),
496 OverrideDataType(output, dataType),
497 reason);
498 break;
499 }
narpra01b89b05f2019-01-16 09:53:09 +0000500 case LayerType::MemCopy:
501 {
502 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
503 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000504
narpra01b89b05f2019-01-16 09:53:09 +0000505 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
506 OverrideDataType(output, dataType),
507 reason);
508 break;
509 }
telsoa014fcda012018-03-09 14:13:49 +0000510 case LayerType::Merger:
511 {
512 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
513
telsoa01c577f2c2018-08-31 09:22:23 +0100514 // Get vector of all inputs.
515 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000516 {
telsoa01c577f2c2018-08-31 09:22:23 +0100517 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000518 };
telsoa01c577f2c2018-08-31 09:22:23 +0100519 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
520 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
521 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000522
telsoa01c577f2c2018-08-31 09:22:23 +0100523 auto getTensorInfoPtr = [](const TensorInfo& info)
524 {
525 return &info;
526 };
527 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
528 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
529 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000530
Nikhil Raj8599a412018-11-19 14:51:07 +0000531 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
532
533 result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
telsoa014fcda012018-03-09 14:13:49 +0000534 break;
535 }
536 case LayerType::Multiplication:
537 {
538 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
539 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100540 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100541 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100542 OverrideDataType(input0, dataType),
543 OverrideDataType(input1, dataType),
544 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100545 reason);
telsoa014fcda012018-03-09 14:13:49 +0000546 break;
547 }
548 case LayerType::Normalization:
549 {
550 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
551 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
552 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100553 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
554 OverrideDataType(output, dataType),
555 cLayer->GetParameters(),
556 reason);
telsoa014fcda012018-03-09 14:13:49 +0000557 break;
558 }
559 case LayerType::Output:
560 {
561 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100562 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000563 break;
564 }
565 case LayerType::Permute:
566 {
567 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
568 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
569 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100570 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
571 OverrideDataType(output, dataType),
572 cLayer->GetParameters(),
573 reason);
telsoa014fcda012018-03-09 14:13:49 +0000574 break;
575 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100576 case LayerType::Pad:
577 {
578 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
579 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
580 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100581 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100582 OverrideDataType(input, dataType),
583 OverrideDataType(output, dataType),
584 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100585 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100586 break;
587 }
telsoa014fcda012018-03-09 14:13:49 +0000588 case LayerType::Pooling2d:
589 {
590 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
591 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
592 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100593 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
594 OverrideDataType(output, dataType),
595 cLayer->GetParameters(),
596 reason);
telsoa014fcda012018-03-09 14:13:49 +0000597 break;
598 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000599 case LayerType::PreCompiled:
600 {
601 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
602 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
603 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
604 cLayer->GetParameters(),
605 reason);
606 break;
607 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100608 case LayerType::Division:
609 {
610 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
611 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
612 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100613 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100614 OverrideDataType(input0, dataType),
615 OverrideDataType(input1, dataType),
616 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100617 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100618 break;
619 }
telsoa014fcda012018-03-09 14:13:49 +0000620 case LayerType::Reshape:
621 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000622 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000623 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000624 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
625 cLayer->GetParameters(),
626 reason);
telsoa014fcda012018-03-09 14:13:49 +0000627 break;
628 }
629 case LayerType::ResizeBilinear:
630 {
631 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganc625f002018-12-17 11:32:16 +0000632 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
633 result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType),
634 OverrideDataType(output, dataType),
635 reason);
telsoa014fcda012018-03-09 14:13:49 +0000636 break;
637 }
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000638 case LayerType::Rsqrt:
639 {
640 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
641 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
642 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
643 OverrideDataType(output, dataType),
644 reason);
645 break;
646 }
telsoa014fcda012018-03-09 14:13:49 +0000647 case LayerType::Softmax:
648 {
649 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
650 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100651 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100652 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
653 OverrideDataType(output, dataType),
654 cLayer->GetParameters(),
655 reason);
telsoa014fcda012018-03-09 14:13:49 +0000656 break;
657 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000658 case LayerType::SpaceToBatchNd:
659 {
660 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
661 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
662 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
663 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
664 OverrideDataType(output, dataType),
665 cLayer->GetParameters(),
666 reason);
667 break;
668 }
telsoa014fcda012018-03-09 14:13:49 +0000669 case LayerType::Splitter:
670 {
671 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
672 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100673 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
674 cLayer->GetParameters(),
675 reason);
telsoa014fcda012018-03-09 14:13:49 +0000676 break;
677 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000678 case LayerType::StridedSlice:
679 {
680 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
681 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
682 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
683 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
684 OverrideDataType(output, dataType),
685 cLayer->GetParameters(),
686 reason);
687 break;
688 }
David Beckc2044fe2018-09-05 15:00:38 +0100689 case LayerType::Subtraction:
690 {
691 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
692 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
693 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100694 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100695 OverrideDataType(input0, dataType),
696 OverrideDataType(input1, dataType),
697 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100698 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100699 break;
700 }
narpra0132b90462018-09-13 11:07:48 +0100701 case LayerType::Mean:
702 {
703 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
704 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
705 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100706 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100707 OverrideDataType(input, dataType),
708 OverrideDataType(output, dataType),
709 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100710 reason);
narpra0132b90462018-09-13 11:07:48 +0100711 break;
712 }
kevmay0190539692018-11-29 08:40:19 +0000713 case LayerType::Minimum:
714 {
715 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
716 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
717 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
718 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
719 OverrideDataType(input1, dataType),
720 OverrideDataType(output, dataType),
721 reason);
722 break;
723 }
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000724 case LayerType::Greater:
725 {
726 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
727 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
728 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Nattapat Chaimanowongc6a41ff2019-01-29 09:56:02 +0000729 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
730 OverrideDataType(input1, dataType),
731 OverrideDataType(output, DataType::Boolean),
732 reason);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000733 break;
734 }
telsoa014fcda012018-03-09 14:13:49 +0000735 default:
736 {
737 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100738 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000739 result = false;
740 break;
741 }
742 }
telsoa014fcda012018-03-09 14:13:49 +0000743 return result;
744}
745
David Beckdcb751f2018-10-03 11:42:42 +0100746bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100747 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100748 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000749{
David Beckdcb751f2018-10-03 11:42:42 +0100750 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100751 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000752}
753
surmeh013537c2c2018-05-18 16:31:43 +0100754}