//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CpuTensorHandle.hpp"

#include <Layer.hpp>
#include <LayersFwd.hpp>

#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <armnn/ILayerSupport.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/IBackendInternal.hpp>

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
#include <sstream>

namespace armnn
{

namespace
{

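// Returns a copy of 'info' with its data type replaced by 'type' when a type is
// supplied; shape and quantization parameters are preserved. An empty Optional
// leaves the info unchanged.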
const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }

    return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
}

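// Maps a weights data type to the matching bias data type: float weights take
// float biases of the same precision, while QuantisedAsymm8 weights take
// Signed32 biases. An empty Optional is passed through unchanged.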
Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
{
    if (!weightsType)
    {
        return weightsType;
    }

    switch(weightsType.value())
    {
        case DataType::Float16:
        case DataType::Float32:
            return weightsType;
        case DataType::QuantisedAsymm8:
            return DataType::Signed32;
        default:
            BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
    }
    return EmptyOptional();
}

} // anonymous namespace

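// Checks whether 'connectableLayer' is supported on the backend identified by
// 'backendId'. When 'dataType' is set, every tensor is validated as if it had
// that data type. On failure, the reason is written to 'outReasonIfUnsupported'.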
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

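    // Dispatch on the layer type and query the backend's ILayerSupport
    // interface with the layer's (possibly type-overridden) tensor infos.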
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                             OverrideDataType(input, dataType),
                                             OverrideDataType(output, dataType),
                                             cLayer->GetParameters(),
                                             reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                             OverrideDataType(input0, dataType),
                                             OverrideDataType(input1, dataType),
                                             OverrideDataType(output, dataType),
                                             reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                             OverrideDataType(input, dataType),
                                             OverrideDataType(output, dataType),
                                             OverrideDataType(mean, dataType),
                                             OverrideDataType(var, dataType),
                                             OverrideDataType(beta, dataType),
                                             OverrideDataType(gamma, dataType),
                                             cLayer->GetParameters(),
                                             reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                             input,
                                             output,
                                             descriptor,
                                             OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                             biases,
                                             reason);
            break;
        }
        case LayerType::Debug:
        {
            auto cLayer = boost::polymorphic_downcast<const DebugLayer*>(&layer);
            const DebugDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          descriptor,
                                                          reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                             input,
                                             output,
                                             descriptor,
                                             OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                             biases,
                                             reason);
            break;
        }
        case LayerType::Equal:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

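            // Use a dummy bias TensorInfo when the layer has no bias of its own,
            // so the backend validation below always receives a well-formed one.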
            TensorInfo biasInfo;
            const TensorInfo * biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy tensorinfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                             OverrideDataType(input, dataType),
                                             OverrideDataType(output, dataType),
                                             OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                             *biasInfoPtr,
                                             descriptor,
                                             reason);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                             OverrideDataType(input, dataType),
                                             OverrideDataType(output, dataType),
                                             descriptor,
                                             reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs.
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters.
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters.
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

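            // With CIFG (coupled input and forget gates) disabled, the input
            // gate has its own weights and bias, so gather them here.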
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

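            // The projection layer is optional; its bias is optional even when
            // projection is enabled.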
            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

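            // Peephole connections add per-gate weights from the cell state.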
            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                             input,
                                             outputStateIn,
                                             cellStateIn,
                                             scratchBuffer,
                                             outputStateOut,
                                             cellStateOut,
                                             output,
                                             descriptor,
                                             inputToForgetWeights,
                                             inputToCellWeights,
                                             inputToOutputWeights,
                                             recurrentToForgetWeights,
                                             recurrentToCellWeights,
                                             recurrentToOutputWeights,
                                             forgetGateBias,
                                             cellBias,
                                             outputGateBias,
                                             inputToInputWeights,
                                             recurrentToInputWeights,
                                             cellToInputWeights,
                                             inputGateBias,
                                             projectionWeights,
                                             projectionBias,
                                             cellToForgetWeights,
                                             cellToOutputWeights,
                                             reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

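            // IsMergerSupported() takes pointers, so build a parallel vector of
            // pointers into 'inputs'.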
            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                             OverrideDataType(input0, dataType),
                                             OverrideDataType(input1, dataType),
                                             OverrideDataType(output, dataType),
                                             reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                             OverrideDataType(input, dataType),
                                             OverrideDataType(output, dataType),
                                             cLayer->GetParameters(),
                                             reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                             OverrideDataType(input0, dataType),
                                             OverrideDataType(input1, dataType),
                                             OverrideDataType(output, dataType),
                                             reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::Rsqrt:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                             OverrideDataType(input0, dataType),
                                             OverrideDataType(input1, dataType),
                                             OverrideDataType(output, dataType),
                                             reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                             OverrideDataType(input, dataType),
                                             OverrideDataType(output, dataType),
                                             cLayer->GetParameters(),
                                             reason);
            break;
        }
        case LayerType::Minimum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Greater:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}

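// Convenience overload: queries support on the backend already assigned to the
// layer via Layer::GetBackendId().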
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
    return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
}
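
// A minimal usage sketch (hypothetical: assumes 'layer' points at a layer of an
// existing network and that the CpuRef backend is registered):
//
//     std::string reason;
//     bool supported = IWorkloadFactory::IsLayerSupported(
//         BackendId("CpuRef"), *layer, DataType::Float32, reason);
//     if (!supported) { /* try another backend, reporting 'reason' */ }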

} // namespace armnn