blob: 915d667fed9ac21d31e26281b849b5258104a50d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
7
8#include <Layer.hpp>
9#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010010
David Beckb4540be2018-09-24 13:18:27 +010011#include <armnn/Types.hpp>
12#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000013#include <armnn/ILayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
David Beck111b5d92018-11-12 14:59:37 +000015#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <backendsCommon/WorkloadFactory.hpp>
David Beck111b5d92018-11-12 14:59:37 +000017#include <backendsCommon/IBackendInternal.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/iterator/transform_iterator.hpp>
21
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000023#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
27
telsoa01c577f2c2018-08-31 09:22:23 +010028namespace
29{
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
David Beck29c75de2018-10-23 13:35:58 +010038 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010039}
40
David Beck29c75de2018-10-23 13:35:58 +010041Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
42{
43 if (!weightsType)
44 {
45 return weightsType;
46 }
47
48 switch(weightsType.value())
49 {
50 case DataType::Float16:
51 case DataType::Float32:
52 return weightsType;
53 case DataType::QuantisedAsymm8:
54 return DataType::Signed32;
55 default:
56 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
57 }
58 return EmptyOptional();
59}
60
61} // anonymous namespace
62
David Beck33f0ae02018-10-18 15:13:56 +010063bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010064 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010065 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010066 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000067{
David Beck33f0ae02018-10-18 15:13:56 +010068 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000069 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010070 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
71
David Beck111b5d92018-11-12 14:59:37 +000072 auto const& backendRegistry = BackendRegistryInstance();
73 if (!backendRegistry.IsBackendRegistered(backendId))
74 {
75 std::stringstream ss;
76 ss << connectableLayer.GetName() << " is not supported on " << backendId
77 << " because this backend is not registered.";
78
79 outReasonIfUnsupported = ss.str();
80 return false;
81 }
82
83 auto backendFactory = backendRegistry.GetFactory(backendId);
84 auto backendObject = backendFactory();
85 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010086
telsoa014fcda012018-03-09 14:13:49 +000087 switch(layer.GetType())
88 {
89 case LayerType::Activation:
90 {
91 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
92 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010093 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010094 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010095 OverrideDataType(input, dataType),
96 OverrideDataType(output, dataType),
97 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010098 reason);
telsoa014fcda012018-03-09 14:13:49 +000099 break;
100 }
101 case LayerType::Addition:
102 {
103 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
104 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100106 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100107 OverrideDataType(input0, dataType),
108 OverrideDataType(input1, dataType),
109 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100110 reason);
telsoa014fcda012018-03-09 14:13:49 +0000111 break;
112 }
113 case LayerType::BatchNormalization:
114 {
115 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100122 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
137 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
138
139 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
143 break;
144 }
telsoa014fcda012018-03-09 14:13:49 +0000145 case LayerType::Constant:
146 {
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100148 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100149 break;
150 }
151 case LayerType::ConvertFp16ToFp32:
152 {
153 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
154 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100155 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100156 break;
157 }
158 case LayerType::ConvertFp32ToFp16:
159 {
160 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100162 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000163 break;
164 }
165 case LayerType::Convolution2d:
166 {
167 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100168
169 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
170 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100171 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100172 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
173
arovir01a6824102018-08-28 17:40:45 +0100174 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100175
arovir01a6824102018-08-28 17:40:45 +0100176 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100177 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100178 if (descriptor.m_BiasEnabled)
179 {
David Beck5eec11d2018-10-04 15:43:17 +0100180 biases =
181 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100182 }
183
David Beck33f0ae02018-10-18 15:13:56 +0100184 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100185 input,
186 output,
187 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100188 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100189 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100190 reason);
telsoa014fcda012018-03-09 14:13:49 +0000191 break;
192 }
193 case LayerType::MemCopy:
194 {
telsoa01c577f2c2018-08-31 09:22:23 +0100195 // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
196 // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
David Beck33f0ae02018-10-18 15:13:56 +0100197 result = backendId == Compute::CpuRef || backendId == Compute::Undefined
198 || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
199 reason.value() = "Unsupported backend type";
telsoa014fcda012018-03-09 14:13:49 +0000200 break;
201 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000202 case LayerType::Debug:
203 {
Nattapat Chaimanowongac5aa1f2018-12-05 15:17:18 +0000204 auto cLayer = boost::polymorphic_downcast<const DebugLayer*>(&layer);
205 const DebugDescriptor& descriptor = cLayer->GetParameters();
206
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000207 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
208 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
209
210 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
211 OverrideDataType(output, dataType),
Nattapat Chaimanowongac5aa1f2018-12-05 15:17:18 +0000212 descriptor,
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000213 reason);
214 break;
215 }
telsoa014fcda012018-03-09 14:13:49 +0000216 case LayerType::DepthwiseConvolution2d:
217 {
218 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100219 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
220 dataType);
221 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
222 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
223
telsoa01c577f2c2018-08-31 09:22:23 +0100224 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100225
226 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100227 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100228 if (descriptor.m_BiasEnabled)
229 {
David Beck5eec11d2018-10-04 15:43:17 +0100230 biases =
231 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100232 }
telsoa01c577f2c2018-08-31 09:22:23 +0100233
David Beck33f0ae02018-10-18 15:13:56 +0100234 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100235 input,
236 output,
237 descriptor,
238 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100239 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100240 reason);
telsoa014fcda012018-03-09 14:13:49 +0000241 break;
242 }
243 case LayerType::FakeQuantization:
244 {
245 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
246 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100247 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
248 cLayer->GetParameters(),
249 reason);
telsoa014fcda012018-03-09 14:13:49 +0000250 break;
251 }
252 case LayerType::Floor:
253 {
254 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
255 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100256 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
257 OverrideDataType(output, dataType),
258 reason);
telsoa014fcda012018-03-09 14:13:49 +0000259 break;
260 }
261 case LayerType::FullyConnected:
262 {
263 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
264 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100265 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
266 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
267
268 TensorInfo biasInfo;
269 const TensorInfo * biasInfoPtr = nullptr;
270 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
271 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
272 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
273
274 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
275 if (descriptor.m_BiasEnabled)
276 {
277 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
278 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
279 biasInfoPtr = &biasInfo;
280 }
281 else
282 {
283 // If biases are not enabled pass a dummy tensorinfo for the validation
284 switch(input.GetDataType())
285 {
286 case DataType::Float16:
287 {
288 biasInfoPtr = &dummyFloat16Bias;
289 break;
290 }
291 case DataType::Float32:
292 {
293 biasInfoPtr = &dummyFloat32Bias;
294 break;
295 }
296 case DataType::QuantisedAsymm8:
297 {
298 biasInfoPtr = &dummyQA8Bias;
299 break;
300 }
301 default:
302 {
303 BOOST_ASSERT_MSG(false, "Unexpected bias type");
304 }
305 }
306 }
307
David Beck33f0ae02018-10-18 15:13:56 +0100308 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100309 OverrideDataType(input, dataType),
310 OverrideDataType(output, dataType),
311 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
312 *biasInfoPtr,
313 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100314 reason);
telsoa014fcda012018-03-09 14:13:49 +0000315 break;
316 }
317 case LayerType::Input:
318 {
319 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100320 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000321 break;
322 }
323 case LayerType::L2Normalization:
324 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100325 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
326 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
327
telsoa014fcda012018-03-09 14:13:49 +0000328 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100329 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100330
David Beck33f0ae02018-10-18 15:13:56 +0100331 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100332 OverrideDataType(input, dataType),
333 OverrideDataType(output, dataType),
334 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100335 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100336 break;
337 }
338 case LayerType::Lstm:
339 {
340 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
341 const LstmDescriptor& descriptor = cLayer->GetParameters();
342
343 // All inputs.
344 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
345 dataType);
346 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
347 dataType);
348 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
349 dataType);
350 // All outputs
351 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
352 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
353 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
354 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
355
356 // Basic parameters
357 const TensorInfo& inputToForgetWeights
358 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
359 const TensorInfo& inputToCellWeights
360 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
361 const TensorInfo& inputToOutputWeights
362 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
363 const TensorInfo& recurrentToForgetWeights
364 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
365 const TensorInfo& recurrentToCellWeights
366 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
367 const TensorInfo& recurrentToOutputWeights
368 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
369 const TensorInfo& forgetGateBias
370 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
371 const TensorInfo& cellBias
372 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
373 const TensorInfo& outputGateBias
374 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
375
376 // Optional parameters
377 const TensorInfo* inputToInputWeights = nullptr;
378 const TensorInfo* recurrentToInputWeights = nullptr;
379 const TensorInfo* cellToInputWeights = nullptr;
380 const TensorInfo* inputGateBias = nullptr;
381 const TensorInfo* projectionWeights = nullptr;
382 const TensorInfo* projectionBias = nullptr;
383 const TensorInfo* cellToForgetWeights = nullptr;
384 const TensorInfo* cellToOutputWeights = nullptr;
385
386 TensorInfo optInputToInputWeights;
387 TensorInfo optRecurrentToInputWeights;
388 TensorInfo optCellToInputWeights;
389 TensorInfo optInputGateBias;
390 TensorInfo optProjectionWeights;
391 TensorInfo optProjectionBias;
392 TensorInfo optCellToForgetWeights;
393 TensorInfo optCellToOutputWeights;
394
395 if(!descriptor.m_CifgEnabled)
396 {
397 optInputToInputWeights =
398 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
399 inputToInputWeights = &optInputToInputWeights;
400
401 optRecurrentToInputWeights =
402 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
403 recurrentToInputWeights = &optRecurrentToInputWeights;
404 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
405 {
406 optCellToInputWeights =
407 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
408 cellToInputWeights = &optCellToInputWeights;
409 }
410 optInputGateBias =
411 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
412 inputGateBias = &optInputGateBias;
413 }
414
415 if(descriptor.m_ProjectionEnabled)
416 {
417 optProjectionWeights =
418 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
419 projectionWeights = &optProjectionWeights;
420 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
421 {
422 optProjectionBias =
423 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
424 projectionBias = &optProjectionBias;
425 }
426 }
427
428 if(descriptor.m_PeepholeEnabled)
429 {
430 optCellToForgetWeights =
431 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
432 cellToForgetWeights = &optCellToForgetWeights;
433 optCellToOutputWeights =
434 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
435 cellToOutputWeights = &optCellToOutputWeights;
436 }
437
David Beck33f0ae02018-10-18 15:13:56 +0100438 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100439 input,
440 outputStateIn,
441 cellStateIn,
442 scratchBuffer,
443 outputStateOut,
444 cellStateOut,
445 output,
446 descriptor,
447 inputToForgetWeights,
448 inputToCellWeights,
449 inputToOutputWeights,
450 recurrentToForgetWeights,
451 recurrentToCellWeights,
452 recurrentToOutputWeights,
453 forgetGateBias,
454 cellBias,
455 outputGateBias,
456 inputToInputWeights,
457 recurrentToInputWeights,
458 cellToInputWeights,
459 inputGateBias,
460 projectionWeights,
461 projectionBias,
462 cellToForgetWeights,
463 cellToOutputWeights,
David Beck33f0ae02018-10-18 15:13:56 +0100464 reason);
telsoa014fcda012018-03-09 14:13:49 +0000465 break;
466 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000467 case LayerType::Maximum:
468 {
469 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
470 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
471 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
472
473 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
474 OverrideDataType(input1, dataType),
475 OverrideDataType(output, dataType),
476 reason);
477 break;
478 }
telsoa014fcda012018-03-09 14:13:49 +0000479 case LayerType::Merger:
480 {
481 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
482
telsoa01c577f2c2018-08-31 09:22:23 +0100483 // Get vector of all inputs.
484 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000485 {
telsoa01c577f2c2018-08-31 09:22:23 +0100486 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000487 };
telsoa01c577f2c2018-08-31 09:22:23 +0100488 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
489 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
490 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000491
telsoa01c577f2c2018-08-31 09:22:23 +0100492 auto getTensorInfoPtr = [](const TensorInfo& info)
493 {
494 return &info;
495 };
496 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
497 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
498 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000499
Nikhil Raj8599a412018-11-19 14:51:07 +0000500 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
501
502 result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
telsoa014fcda012018-03-09 14:13:49 +0000503 break;
504 }
505 case LayerType::Multiplication:
506 {
507 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
508 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100509 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100510 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100511 OverrideDataType(input0, dataType),
512 OverrideDataType(input1, dataType),
513 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100514 reason);
telsoa014fcda012018-03-09 14:13:49 +0000515 break;
516 }
517 case LayerType::Normalization:
518 {
519 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
520 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
521 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100522 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
523 OverrideDataType(output, dataType),
524 cLayer->GetParameters(),
525 reason);
telsoa014fcda012018-03-09 14:13:49 +0000526 break;
527 }
528 case LayerType::Output:
529 {
530 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100531 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000532 break;
533 }
534 case LayerType::Permute:
535 {
536 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
537 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
538 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100539 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
540 OverrideDataType(output, dataType),
541 cLayer->GetParameters(),
542 reason);
telsoa014fcda012018-03-09 14:13:49 +0000543 break;
544 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100545 case LayerType::Pad:
546 {
547 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
548 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
549 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100550 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100551 OverrideDataType(input, dataType),
552 OverrideDataType(output, dataType),
553 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100554 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100555 break;
556 }
telsoa014fcda012018-03-09 14:13:49 +0000557 case LayerType::Pooling2d:
558 {
559 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
560 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
561 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100562 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
563 OverrideDataType(output, dataType),
564 cLayer->GetParameters(),
565 reason);
telsoa014fcda012018-03-09 14:13:49 +0000566 break;
567 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100568 case LayerType::Division:
569 {
570 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
571 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
572 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100573 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100574 OverrideDataType(input0, dataType),
575 OverrideDataType(input1, dataType),
576 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100577 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100578 break;
579 }
telsoa014fcda012018-03-09 14:13:49 +0000580 case LayerType::Reshape:
581 {
582 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100583 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000584 break;
585 }
586 case LayerType::ResizeBilinear:
587 {
588 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100589 result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000590 break;
591 }
592 case LayerType::Softmax:
593 {
594 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
595 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100596 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100597 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
598 OverrideDataType(output, dataType),
599 cLayer->GetParameters(),
600 reason);
telsoa014fcda012018-03-09 14:13:49 +0000601 break;
602 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000603 case LayerType::SpaceToBatchNd:
604 {
605 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
606 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
607 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
608 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
609 OverrideDataType(output, dataType),
610 cLayer->GetParameters(),
611 reason);
612 break;
613 }
telsoa014fcda012018-03-09 14:13:49 +0000614 case LayerType::Splitter:
615 {
616 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
617 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100618 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
619 cLayer->GetParameters(),
620 reason);
telsoa014fcda012018-03-09 14:13:49 +0000621 break;
622 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000623 case LayerType::StridedSlice:
624 {
625 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
626 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
627 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
628 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
629 OverrideDataType(output, dataType),
630 cLayer->GetParameters(),
631 reason);
632 break;
633 }
David Beckc2044fe2018-09-05 15:00:38 +0100634 case LayerType::Subtraction:
635 {
636 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
637 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
638 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100639 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100640 OverrideDataType(input0, dataType),
641 OverrideDataType(input1, dataType),
642 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100643 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100644 break;
645 }
narpra0132b90462018-09-13 11:07:48 +0100646 case LayerType::Mean:
647 {
648 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
649 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
650 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100651 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100652 OverrideDataType(input, dataType),
653 OverrideDataType(output, dataType),
654 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100655 reason);
narpra0132b90462018-09-13 11:07:48 +0100656 break;
657 }
kevmay0190539692018-11-29 08:40:19 +0000658 case LayerType::Minimum:
659 {
660 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
661 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
662 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
663 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
664 OverrideDataType(input1, dataType),
665 OverrideDataType(output, dataType),
666 reason);
667 break;
668 }
telsoa014fcda012018-03-09 14:13:49 +0000669 default:
670 {
671 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100672 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000673 result = false;
674 break;
675 }
676 }
telsoa014fcda012018-03-09 14:13:49 +0000677 return result;
678}
679
David Beckdcb751f2018-10-03 11:42:42 +0100680bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100681 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100682 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000683{
David Beckdcb751f2018-10-03 11:42:42 +0100684 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100685 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000686}
687
surmeh013537c2c2018-05-18 16:31:43 +0100688}