blob: bb63b336e92a699e1290cb0d2176c6e1d7e2c9e4 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
7
8#include <Layer.hpp>
9#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010010
David Beckb4540be2018-09-24 13:18:27 +010011#include <armnn/Types.hpp>
12#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000013#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
David Beck111b5d92018-11-12 14:59:37 +000015#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/IBackendInternal.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/iterator/transform_iterator.hpp>
21
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000023#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
27
telsoa01c577f2c2018-08-31 09:22:23 +010028namespace
29{
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
David Beck29c75de2018-10-23 13:35:58 +010038 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010039}
40
David Beck29c75de2018-10-23 13:35:58 +010041Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
42{
43 if (!weightsType)
44 {
45 return weightsType;
46 }
47
48 switch(weightsType.value())
49 {
50 case DataType::Float16:
51 case DataType::Float32:
52 return weightsType;
53 case DataType::QuantisedAsymm8:
54 return DataType::Signed32;
55 default:
56 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
57 }
58 return EmptyOptional();
59}
60
61} // anonymous namespace
62
/// Queries the given backend's ILayerSupport object to determine whether the
/// supplied layer can run on that backend, with an optional data-type override
/// applied to the layer's tensor infos before the query.
///
/// @param backendId             Backend to query; must be registered, otherwise
///                              false is returned with an explanatory reason.
/// @param connectableLayer      Layer to check (internally downcast to Layer).
/// @param dataType              Optional data type that overrides the tensors'
///                              own types (empty = use the tensors as-is).
/// @param outReasonIfUnsupported Receives a human-readable reason on failure.
/// @return true if the backend reports the layer as supported.
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    // Wrap the caller's string so the support functions can write into it.
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    // Public API exposes IConnectableLayer; the concrete Layer is needed to
    // reach slots, parameters and typed sub-layer data.
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    // Instantiate the backend via its registered factory and fetch the
    // object that answers per-layer support queries.
    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                        OverrideDataType(input0, dataType),
                                        OverrideDataType(input1, dataType),
                                        OverrideDataType(output, dataType),
                                        reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            // Statistics and scale/shift tensors are owned by the layer itself.
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(mean, dataType),
                                           OverrideDataType(var, dataType),
                                           OverrideDataType(beta, dataType),
                                           OverrideDataType(gamma, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            // Conversion layers have fixed input/output types, so the
            // dataType override is deliberately not applied here.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                           input,
                                           output,
                                           descriptor,
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           biases,
                                           reason);
            break;
        }
        case LayerType::MemCopy:
        {
            // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
            // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
            result = backendId == Compute::CpuRef || backendId == Compute::Undefined
                     || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
            // Reason is set unconditionally; it is only meaningful to callers
            // when result is false.
            reason.value() = "Unsupported backend type";
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                         input,
                                         output,
                                         descriptor,
                                         OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                         biases,
                                         reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo * biasInfoPtr = nullptr;
            // Placeholder bias infos used when the layer has no bias, so the
            // backend query always receives a valid TensorInfo reference.
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy tensorinfo for the validation
                // NOTE(review): this switches on the input's own data type, not
                // the dataType override applied to the query above — confirm
                // this is intended when an override is in effect.
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                        OverrideDataType(input, dataType),
                                        OverrideDataType(output, dataType),
                                        OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                        *biasInfoPtr,
                                        descriptor,
                                        reason);
            break;
        }
        case LayerType::Input:
        {
            // An Input layer produces its tensor on its output slot.
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                        OverrideDataType(input, dataType),
                                        OverrideDataType(output, dataType),
                                        descriptor,
                                        reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters — passed as null pointers unless the
            // corresponding descriptor feature is enabled below.
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            // Backing storage for the optional TensorInfos; must outlive the
            // IsLstmSupported call since the pointers above refer into these.
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

            // CIFG disabled means the input gate parameters are present.
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                     input,
                                     outputStateIn,
                                     cellStateIn,
                                     scratchBuffer,
                                     outputStateOut,
                                     cellStateOut,
                                     output,
                                     descriptor,
                                     inputToForgetWeights,
                                     inputToCellWeights,
                                     inputToOutputWeights,
                                     recurrentToForgetWeights,
                                     recurrentToCellWeights,
                                     recurrentToOutputWeights,
                                     forgetGateBias,
                                     cellBias,
                                     outputGateBias,
                                     inputToInputWeights,
                                     recurrentToInputWeights,
                                     cellToInputWeights,
                                     inputGateBias,
                                     projectionWeights,
                                     projectionBias,
                                     cellToForgetWeights,
                                     cellToOutputWeights,
                                     reason);
            break;
        }
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            // Build a parallel vector of pointers into `inputs`, which is the
            // form the support query expects.
            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            result = layerSupportObject->IsMergerSupported(inputPtrs, cLayer->GetParameters(), reason);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            // An Output layer consumes its tensor on its input slot.
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                    OverrideDataType(input, dataType),
                                    OverrideDataType(output, dataType),
                                    cLayer->GetParameters(),
                                    reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                         OverrideDataType(input0, dataType),
                                         OverrideDataType(input1, dataType),
                                         OverrideDataType(output, dataType),
                                         reason);
            break;
        }
        case LayerType::Reshape:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                     OverrideDataType(input, dataType),
                                     OverrideDataType(output, dataType),
                                     cLayer->GetParameters(),
                                     reason);
            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}
629
David Beckdcb751f2018-10-03 11:42:42 +0100630bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100631 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100632 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000633{
David Beckdcb751f2018-10-03 11:42:42 +0100634 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100635 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000636}
637
surmeh013537c2c2018-05-18 16:31:43 +0100638}