blob: 83a20e8675afdfb975c3f83ac8bcb4c34c6c502e [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
6#include "CpuTensorHandle.hpp"
7
8#include <Layer.hpp>
9#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +010010
David Beckb4540be2018-09-24 13:18:27 +010011#include <armnn/Types.hpp>
12#include <armnn/LayerSupport.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013
14#include <backendsCommon/LayerSupportRegistry.hpp>
15#include <backendsCommon/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
17#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018#include <boost/iterator/transform_iterator.hpp>
19
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000020#include <cstring>
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
telsoa01c577f2c2018-08-31 09:22:23 +010025namespace
26{
telsoa01c577f2c2018-08-31 09:22:23 +010027
David Beck29c75de2018-10-23 13:35:58 +010028const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
29{
30 if (!type)
31 {
32 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010033 }
34
David Beck29c75de2018-10-23 13:35:58 +010035 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010036}
37
David Beck29c75de2018-10-23 13:35:58 +010038Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
39{
40 if (!weightsType)
41 {
42 return weightsType;
43 }
44
45 switch(weightsType.value())
46 {
47 case DataType::Float16:
48 case DataType::Float32:
49 return weightsType;
50 case DataType::QuantisedAsymm8:
51 return DataType::Signed32;
52 default:
53 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
54 }
55 return EmptyOptional();
56}
57
58} // anonymous namespace
59
David Beck33f0ae02018-10-18 15:13:56 +010060bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010061 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010062 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010063 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000064{
David Beck33f0ae02018-10-18 15:13:56 +010065 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000066 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010067 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
68
David Beck33f0ae02018-10-18 15:13:56 +010069 auto const& layerSupportRegistry = LayerSupportRegistryInstance();
70 auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId);
David Beckd4dfa682018-10-24 17:09:46 +010071 auto layerSupportObject = layerSupportFactory(EmptyInitializer());
David Beck33f0ae02018-10-18 15:13:56 +010072
telsoa014fcda012018-03-09 14:13:49 +000073 switch(layer.GetType())
74 {
75 case LayerType::Activation:
76 {
77 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
78 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010080 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010081 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010084 reason);
telsoa014fcda012018-03-09 14:13:49 +000085 break;
86 }
87 case LayerType::Addition:
88 {
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010092 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010093 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010096 reason);
telsoa014fcda012018-03-09 14:13:49 +000097 break;
98 }
99 case LayerType::BatchNormalization:
100 {
101 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
102 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100103 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
104 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
105 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
106 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
107 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100108 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100109 OverrideDataType(input, dataType),
110 OverrideDataType(output, dataType),
111 OverrideDataType(mean, dataType),
112 OverrideDataType(var, dataType),
113 OverrideDataType(beta, dataType),
114 OverrideDataType(gamma, dataType),
115 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100116 reason);
telsoa014fcda012018-03-09 14:13:49 +0000117 break;
118 }
119 case LayerType::Constant:
120 {
121 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100122 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100123 break;
124 }
125 case LayerType::ConvertFp16ToFp32:
126 {
127 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
128 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100129 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100130 break;
131 }
132 case LayerType::ConvertFp32ToFp16:
133 {
134 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
135 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100136 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000137 break;
138 }
139 case LayerType::Convolution2d:
140 {
141 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100142
143 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
144 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100145 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100146 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
147
arovir01a6824102018-08-28 17:40:45 +0100148 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100149
arovir01a6824102018-08-28 17:40:45 +0100150 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100151 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100152 if (descriptor.m_BiasEnabled)
153 {
David Beck5eec11d2018-10-04 15:43:17 +0100154 biases =
155 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100156 }
157
David Beck33f0ae02018-10-18 15:13:56 +0100158 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100159 input,
160 output,
161 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100162 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100163 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100164 reason);
telsoa014fcda012018-03-09 14:13:49 +0000165 break;
166 }
167 case LayerType::MemCopy:
168 {
telsoa01c577f2c2018-08-31 09:22:23 +0100169 // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
170 // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
David Beck33f0ae02018-10-18 15:13:56 +0100171 result = backendId == Compute::CpuRef || backendId == Compute::Undefined
172 || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
173 reason.value() = "Unsupported backend type";
telsoa014fcda012018-03-09 14:13:49 +0000174 break;
175 }
176 case LayerType::DepthwiseConvolution2d:
177 {
178 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100179 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
180 dataType);
181 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
182 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
183
telsoa01c577f2c2018-08-31 09:22:23 +0100184 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100185
186 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100187 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100188 if (descriptor.m_BiasEnabled)
189 {
David Beck5eec11d2018-10-04 15:43:17 +0100190 biases =
191 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100192 }
telsoa01c577f2c2018-08-31 09:22:23 +0100193
David Beck33f0ae02018-10-18 15:13:56 +0100194 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100195 input,
196 output,
197 descriptor,
198 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100199 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100200 reason);
telsoa014fcda012018-03-09 14:13:49 +0000201 break;
202 }
203 case LayerType::FakeQuantization:
204 {
205 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
206 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100207 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
208 cLayer->GetParameters(),
209 reason);
telsoa014fcda012018-03-09 14:13:49 +0000210 break;
211 }
212 case LayerType::Floor:
213 {
214 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
215 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100216 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
217 OverrideDataType(output, dataType),
218 reason);
telsoa014fcda012018-03-09 14:13:49 +0000219 break;
220 }
221 case LayerType::FullyConnected:
222 {
223 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
224 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100225 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
226 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
227
228 TensorInfo biasInfo;
229 const TensorInfo * biasInfoPtr = nullptr;
230 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
231 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
232 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
233
234 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
235 if (descriptor.m_BiasEnabled)
236 {
237 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
238 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
239 biasInfoPtr = &biasInfo;
240 }
241 else
242 {
243 // If biases are not enabled pass a dummy tensorinfo for the validation
244 switch(input.GetDataType())
245 {
246 case DataType::Float16:
247 {
248 biasInfoPtr = &dummyFloat16Bias;
249 break;
250 }
251 case DataType::Float32:
252 {
253 biasInfoPtr = &dummyFloat32Bias;
254 break;
255 }
256 case DataType::QuantisedAsymm8:
257 {
258 biasInfoPtr = &dummyQA8Bias;
259 break;
260 }
261 default:
262 {
263 BOOST_ASSERT_MSG(false, "Unexpected bias type");
264 }
265 }
266 }
267
David Beck33f0ae02018-10-18 15:13:56 +0100268 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100269 OverrideDataType(input, dataType),
270 OverrideDataType(output, dataType),
271 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
272 *biasInfoPtr,
273 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100274 reason);
telsoa014fcda012018-03-09 14:13:49 +0000275 break;
276 }
277 case LayerType::Input:
278 {
279 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100280 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000281 break;
282 }
283 case LayerType::L2Normalization:
284 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100285 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
286 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
287
telsoa014fcda012018-03-09 14:13:49 +0000288 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100289 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100290
David Beck33f0ae02018-10-18 15:13:56 +0100291 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100292 OverrideDataType(input, dataType),
293 OverrideDataType(output, dataType),
294 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100295 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100296 break;
297 }
298 case LayerType::Lstm:
299 {
300 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
301 const LstmDescriptor& descriptor = cLayer->GetParameters();
302
303 // All inputs.
304 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
305 dataType);
306 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
307 dataType);
308 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
309 dataType);
310 // All outputs
311 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
312 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
313 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
314 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
315
316 // Basic parameters
317 const TensorInfo& inputToForgetWeights
318 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
319 const TensorInfo& inputToCellWeights
320 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
321 const TensorInfo& inputToOutputWeights
322 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
323 const TensorInfo& recurrentToForgetWeights
324 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
325 const TensorInfo& recurrentToCellWeights
326 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
327 const TensorInfo& recurrentToOutputWeights
328 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
329 const TensorInfo& forgetGateBias
330 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
331 const TensorInfo& cellBias
332 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
333 const TensorInfo& outputGateBias
334 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
335
336 // Optional parameters
337 const TensorInfo* inputToInputWeights = nullptr;
338 const TensorInfo* recurrentToInputWeights = nullptr;
339 const TensorInfo* cellToInputWeights = nullptr;
340 const TensorInfo* inputGateBias = nullptr;
341 const TensorInfo* projectionWeights = nullptr;
342 const TensorInfo* projectionBias = nullptr;
343 const TensorInfo* cellToForgetWeights = nullptr;
344 const TensorInfo* cellToOutputWeights = nullptr;
345
346 TensorInfo optInputToInputWeights;
347 TensorInfo optRecurrentToInputWeights;
348 TensorInfo optCellToInputWeights;
349 TensorInfo optInputGateBias;
350 TensorInfo optProjectionWeights;
351 TensorInfo optProjectionBias;
352 TensorInfo optCellToForgetWeights;
353 TensorInfo optCellToOutputWeights;
354
355 if(!descriptor.m_CifgEnabled)
356 {
357 optInputToInputWeights =
358 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
359 inputToInputWeights = &optInputToInputWeights;
360
361 optRecurrentToInputWeights =
362 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
363 recurrentToInputWeights = &optRecurrentToInputWeights;
364 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
365 {
366 optCellToInputWeights =
367 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
368 cellToInputWeights = &optCellToInputWeights;
369 }
370 optInputGateBias =
371 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
372 inputGateBias = &optInputGateBias;
373 }
374
375 if(descriptor.m_ProjectionEnabled)
376 {
377 optProjectionWeights =
378 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
379 projectionWeights = &optProjectionWeights;
380 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
381 {
382 optProjectionBias =
383 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
384 projectionBias = &optProjectionBias;
385 }
386 }
387
388 if(descriptor.m_PeepholeEnabled)
389 {
390 optCellToForgetWeights =
391 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
392 cellToForgetWeights = &optCellToForgetWeights;
393 optCellToOutputWeights =
394 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
395 cellToOutputWeights = &optCellToOutputWeights;
396 }
397
David Beck33f0ae02018-10-18 15:13:56 +0100398 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100399 input,
400 outputStateIn,
401 cellStateIn,
402 scratchBuffer,
403 outputStateOut,
404 cellStateOut,
405 output,
406 descriptor,
407 inputToForgetWeights,
408 inputToCellWeights,
409 inputToOutputWeights,
410 recurrentToForgetWeights,
411 recurrentToCellWeights,
412 recurrentToOutputWeights,
413 forgetGateBias,
414 cellBias,
415 outputGateBias,
416 inputToInputWeights,
417 recurrentToInputWeights,
418 cellToInputWeights,
419 inputGateBias,
420 projectionWeights,
421 projectionBias,
422 cellToForgetWeights,
423 cellToOutputWeights,
David Beck33f0ae02018-10-18 15:13:56 +0100424 reason);
telsoa014fcda012018-03-09 14:13:49 +0000425 break;
426 }
427 case LayerType::Merger:
428 {
429 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
430
telsoa01c577f2c2018-08-31 09:22:23 +0100431 // Get vector of all inputs.
432 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000433 {
telsoa01c577f2c2018-08-31 09:22:23 +0100434 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000435 };
telsoa01c577f2c2018-08-31 09:22:23 +0100436 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
437 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
438 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000439
telsoa01c577f2c2018-08-31 09:22:23 +0100440 auto getTensorInfoPtr = [](const TensorInfo& info)
441 {
442 return &info;
443 };
444 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
445 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
446 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000447
David Beck33f0ae02018-10-18 15:13:56 +0100448 result = layerSupportObject->IsMergerSupported(inputPtrs, cLayer->GetParameters(), reason);
telsoa014fcda012018-03-09 14:13:49 +0000449 break;
450 }
451 case LayerType::Multiplication:
452 {
453 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
454 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100455 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100456 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100457 OverrideDataType(input0, dataType),
458 OverrideDataType(input1, dataType),
459 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100460 reason);
telsoa014fcda012018-03-09 14:13:49 +0000461 break;
462 }
463 case LayerType::Normalization:
464 {
465 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
466 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
467 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100468 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
469 OverrideDataType(output, dataType),
470 cLayer->GetParameters(),
471 reason);
telsoa014fcda012018-03-09 14:13:49 +0000472 break;
473 }
474 case LayerType::Output:
475 {
476 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100477 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000478 break;
479 }
480 case LayerType::Permute:
481 {
482 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
483 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
484 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100485 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
486 OverrideDataType(output, dataType),
487 cLayer->GetParameters(),
488 reason);
telsoa014fcda012018-03-09 14:13:49 +0000489 break;
490 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100491 case LayerType::Pad:
492 {
493 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
494 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
495 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100496 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100497 OverrideDataType(input, dataType),
498 OverrideDataType(output, dataType),
499 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100500 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100501 break;
502 }
telsoa014fcda012018-03-09 14:13:49 +0000503 case LayerType::Pooling2d:
504 {
505 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
506 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
507 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100508 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
509 OverrideDataType(output, dataType),
510 cLayer->GetParameters(),
511 reason);
telsoa014fcda012018-03-09 14:13:49 +0000512 break;
513 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100514 case LayerType::Division:
515 {
516 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
517 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
518 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100519 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100520 OverrideDataType(input0, dataType),
521 OverrideDataType(input1, dataType),
522 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100523 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100524 break;
525 }
telsoa014fcda012018-03-09 14:13:49 +0000526 case LayerType::Reshape:
527 {
528 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100529 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000530 break;
531 }
532 case LayerType::ResizeBilinear:
533 {
534 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100535 result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000536 break;
537 }
538 case LayerType::Softmax:
539 {
540 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
541 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100542 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100543 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
544 OverrideDataType(output, dataType),
545 cLayer->GetParameters(),
546 reason);
telsoa014fcda012018-03-09 14:13:49 +0000547 break;
548 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000549 case LayerType::SpaceToBatchNd:
550 {
551 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
552 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
553 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
554 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
555 OverrideDataType(output, dataType),
556 cLayer->GetParameters(),
557 reason);
558 break;
559 }
telsoa014fcda012018-03-09 14:13:49 +0000560 case LayerType::Splitter:
561 {
562 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
563 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100564 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
565 cLayer->GetParameters(),
566 reason);
telsoa014fcda012018-03-09 14:13:49 +0000567 break;
568 }
David Beckc2044fe2018-09-05 15:00:38 +0100569 case LayerType::Subtraction:
570 {
571 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
572 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
573 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100574 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100575 OverrideDataType(input0, dataType),
576 OverrideDataType(input1, dataType),
577 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100578 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100579 break;
580 }
narpra0132b90462018-09-13 11:07:48 +0100581 case LayerType::Mean:
582 {
583 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
584 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
585 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100586 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +0100587 OverrideDataType(input, dataType),
588 OverrideDataType(output, dataType),
589 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100590 reason);
narpra0132b90462018-09-13 11:07:48 +0100591 break;
592 }
telsoa014fcda012018-03-09 14:13:49 +0000593 default:
594 {
595 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +0100596 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +0000597 result = false;
598 break;
599 }
600 }
telsoa014fcda012018-03-09 14:13:49 +0000601 return result;
602}
603
David Beckdcb751f2018-10-03 11:42:42 +0100604bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +0100605 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100606 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000607{
David Beckdcb751f2018-10-03 11:42:42 +0100608 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100609 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000610}
611
surmeh013537c2c2018-05-18 16:31:43 +0100612}