blob: fc79018aa58913b50518eac33bf77364213812f3 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
7
8#include <Layer.hpp>
9#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010010
David Beckb4540be2018-09-24 13:18:27 +010011#include <armnn/Types.hpp>
12#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000013#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
David Beck111b5d92018-11-12 14:59:37 +000015#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/IBackendInternal.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/iterator/transform_iterator.hpp>
21
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000023#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
27
telsoa01c577f2c2018-08-31 09:22:23 +010028namespace
29{
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
David Beck29c75de2018-10-23 13:35:58 +010038 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010039}
40
David Beck29c75de2018-10-23 13:35:58 +010041Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
42{
43 if (!weightsType)
44 {
45 return weightsType;
46 }
47
48 switch(weightsType.value())
49 {
50 case DataType::Float16:
51 case DataType::Float32:
52 return weightsType;
53 case DataType::QuantisedAsymm8:
54 return DataType::Signed32;
55 default:
56 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
57 }
58 return EmptyOptional();
59}
60
61} // anonymous namespace
62
/// Queries whether @p connectableLayer can run on the backend identified by
/// @p backendId, optionally overriding every tensor's data type with
/// @p dataType (used e.g. to probe Fp16 support for an Fp32-described graph).
///
/// @param backendId             Backend to query; must be registered in the
///                              BackendRegistry, otherwise false is returned
///                              with an explanatory reason.
/// @param connectableLayer      The layer to check (downcast internally to Layer).
/// @param dataType              Optional data type override applied via
///                              OverrideDataType before querying the backend.
/// @param outReasonIfUnsupported Receives a human-readable reason when the
///                              layer is not supported.
/// @return true if the backend's ILayerSupport object accepts the layer.
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    // Wrap the caller's string so the ILayerSupport API can fill it in place.
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    // An unregistered backend can never support anything — fail fast with a
    // descriptive reason instead of querying a non-existent factory.
    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    // Instantiate the backend and obtain its layer-support query object.
    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

    // Dispatch on the concrete layer type; each case collects the relevant
    // tensor infos (with the optional data-type override applied) and forwards
    // to the matching ILayerSupport query.
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                        OverrideDataType(input0, dataType),
                                        OverrideDataType(input1, dataType),
                                        OverrideDataType(output, dataType),
                                        reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            // Statistics and affine parameters are owned by the layer itself.
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(mean, dataType),
                                           OverrideDataType(var, dataType),
                                           OverrideDataType(beta, dataType),
                                           OverrideDataType(gamma, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            // Conversion layers have fixed input/output types by definition,
            // so the dataType override is deliberately not applied here.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                          input,
                                          output,
                                          descriptor,
                                          OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                          biases,
                                          reason);
            break;
        }
        case LayerType::MemCopy:
        {
            // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
            // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
            result = backendId == Compute::CpuRef || backendId == Compute::Undefined
                     || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
            // NOTE(review): the reason string is written even when result is
            // true; callers appear to consult it only on failure — confirm.
            reason.value() = "Unsupported backend type";
            break;
        }
        case LayerType::Debug:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                         input,
                                         output,
                                         descriptor,
                                         OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                         biases,
                                         reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            // The support query always takes a bias TensorInfo; when biases are
            // disabled a static dummy matching the input type is passed instead.
            TensorInfo biasInfo;
            const TensorInfo * biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy tensorinfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                            *biasInfoPtr,
                                            descriptor,
                                            reason);
            break;
        }
        case LayerType::Input:
        {
            // Input layers have no input slots; their tensor is the output slot.
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                            OverrideDataType(input, dataType),
                                            OverrideDataType(output, dataType),
                                            descriptor,
                                            reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters: null pointers signal "not present" to the
            // support query; the opt* locals provide storage when present.
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

            // CIFG disabled => the input gate and its weights are present.
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                     input,
                                     outputStateIn,
                                     cellStateIn,
                                     scratchBuffer,
                                     outputStateOut,
                                     cellStateOut,
                                     output,
                                     descriptor,
                                     inputToForgetWeights,
                                     inputToCellWeights,
                                     inputToOutputWeights,
                                     recurrentToForgetWeights,
                                     recurrentToCellWeights,
                                     recurrentToOutputWeights,
                                     forgetGateBias,
                                     cellBias,
                                     outputGateBias,
                                     inputToInputWeights,
                                     recurrentToInputWeights,
                                     cellToInputWeights,
                                     inputGateBias,
                                     projectionWeights,
                                     projectionBias,
                                     cellToForgetWeights,
                                     cellToOutputWeights,
                                     reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            // The support API wants pointers; build them over the materialised
            // vector (not the transform iterators) so they stay valid.
            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            // Output layers have no output slots; their tensor is the input slot.
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                    OverrideDataType(input, dataType),
                                    OverrideDataType(output, dataType),
                                    cLayer->GetParameters(),
                                    reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                         OverrideDataType(input0, dataType),
                                         OverrideDataType(input1, dataType),
                                         OverrideDataType(output, dataType),
                                         reason);
            break;
        }
        case LayerType::Reshape:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                     OverrideDataType(input, dataType),
                                     OverrideDataType(output, dataType),
                                     cLayer->GetParameters(),
                                     reason);
            break;
        }
        case LayerType::Minimum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        default:
        {
            // Unknown layer type: a debug-build assert plus a graceful
            // "unsupported" result in release builds.
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}
675
David Beckdcb751f2018-10-03 11:42:42 +0100676bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100677 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100678 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000679{
David Beckdcb751f2018-10-03 11:42:42 +0100680 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100681 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000682}
683
surmeh013537c2c2018-05-18 16:31:43 +0100684}