blob: e7dec49db43abc7baadc7c0d28f601c3ac5abfd5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
David Beckdcb751f2018-10-03 11:42:42 +01005#include <backends/WorkloadFactory.hpp>
David Beck33f0ae02018-10-18 15:13:56 +01006#include <backends/LayerSupportRegistry.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01007
David Beckb4540be2018-09-24 13:18:27 +01008#include <backends/reference/RefWorkloadFactory.hpp>
David Beck0dbe0ee2018-09-24 15:59:27 +01009#include <backends/neon/NeonWorkloadFactory.hpp>
David Beckac42efd2018-09-26 17:41:13 +010010#include <backends/cl/ClWorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
David Beckb4540be2018-09-24 13:18:27 +010012#include <armnn/Types.hpp>
13#include <armnn/LayerSupport.hpp>
14#include <Layer.hpp>
15#include <LayersFwd.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016#include "CpuTensorHandle.hpp"
17
18#include <boost/cast.hpp>
19#include <cstring>
20#include <boost/iterator/transform_iterator.hpp>
21
22namespace armnn
23{
24
telsoa01c577f2c2018-08-31 09:22:23 +010025namespace
26{
27 const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
28 {
29 if (type == boost::none)
30 {
31 return info;
32 }
33
34 return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
35 }
36
37 boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
38 {
39 if (weightsType == boost::none)
40 {
41 return weightsType;
42 }
43
44 switch(weightsType.get())
45 {
46 case DataType::Float16:
47 case DataType::Float32:
48 return weightsType;
49 case DataType::QuantisedAsymm8:
50 return DataType::Signed32;
51 default:
52 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
53 }
54 return boost::none;
55 }
56}
57
/// Queries the layer-support object registered for @p backendId to decide whether
/// @p connectableLayer can run on that backend.
///
/// @param backendId              Backend to query (e.g. CpuRef, CpuAcc, GpuAcc).
/// @param connectableLayer       The layer to check; must actually be a Layer.
/// @param dataType               If set, overrides the data type of every tensor passed to
///                               the support query (used e.g. for FP16 what-if checks).
/// @param outReasonIfUnsupported Receives a human-readable reason when unsupported.
/// @return true if the backend reports the layer as supported.
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        boost::optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    // Wrap the caller's string so the support functions can write a failure reason into it.
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    // Resolve the per-backend layer-support implementation via the registry.
    auto const& layerSupportRegistry = LayerSupportRegistryInstance();
    auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId);
    auto layerSupportObject = layerSupportFactory();

    // Dispatch on the concrete layer type: each case downcasts, gathers the relevant
    // tensor infos (with the optional data-type override applied), and forwards to the
    // matching Is<Layer>Supported query.
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(mean, dataType),
                                           OverrideDataType(var, dataType),
                                           OverrideDataType(beta, dataType),
                                           OverrideDataType(gamma, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            // Conversion layers have fixed input/output types, so no data-type override here.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                      dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor  = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                           input,
                                           output,
                                           descriptor,
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           biases,
                                           reason);
            break;
        }
        case LayerType::MemCopy:
        {
            // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
            // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
            result = backendId == Compute::CpuRef || backendId == Compute::Undefined
                || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
            // NOTE(review): the reason string is written even when result is true — callers
            // appear to only read it on failure, but confirm before relying on its contents.
            reason.value() = "Unsupported backend type";
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                                     input,
                                                     output,
                                                     descriptor,
                                                     OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                                     biases,
                                                     reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            // The support query always needs a bias TensorInfo; when biases are disabled a
            // static dummy of the appropriate type is substituted (see the else branch below).
            TensorInfo biasInfo;
            const TensorInfo * biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled pass a dummy tensorinfo for the validation
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                           *biasInfoPtr,
                                           descriptor,
                                           reason);
            break;
        }
        case LayerType::Input:
        {
            // Input layers have no input slots; their tensor info lives on the output slot.
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                                OverrideDataType(input, dataType),
                                                OverrideDataType(output, dataType),
                                                descriptor,
                                                reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters: passed as null pointers when the corresponding LSTM
            // feature (CIFG / projection / peephole) is disabled. The opt* locals keep the
            // TensorInfo copies alive for the duration of the support call.
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                     input,
                                     outputStateIn,
                                     cellStateIn,
                                     scratchBuffer,
                                     outputStateOut,
                                     cellStateOut,
                                     output,
                                     descriptor,
                                     inputToForgetWeights,
                                     inputToCellWeights,
                                     inputToOutputWeights,
                                     recurrentToForgetWeights,
                                     recurrentToCellWeights,
                                     recurrentToOutputWeights,
                                     forgetGateBias,
                                     cellBias,
                                     outputGateBias,
                                     inputToInputWeights,
                                     recurrentToInputWeights,
                                     cellToInputWeights,
                                     inputGateBias,
                                     projectionWeights,
                                     projectionBias,
                                     cellToForgetWeights,
                                     cellToOutputWeights,
                                     reason);
            break;
        }
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            // The support API takes pointers, so build a parallel vector of pointers into `inputs`.
            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            result = layerSupportObject->IsMergerSupported(inputPtrs, cLayer->GetParameters(), reason);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                    OverrideDataType(input, dataType),
                                    OverrideDataType(output, dataType),
                                    cLayer->GetParameters(),
                                    reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                         OverrideDataType(input0, dataType),
                                         OverrideDataType(input1, dataType),
                                         OverrideDataType(output, dataType),
                                         reason);
            break;
        }
        case LayerType::Reshape:
        {
            // NOTE(review): only the input tensor is validated here — the output info is
            // not passed to the support query; confirm this matches the backend API.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::Subtraction:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSubtractionSupported(
                                            OverrideDataType(input0, dataType),
                                            OverrideDataType(input1, dataType),
                                            OverrideDataType(output, dataType),
                                            reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMeanSupported(
                                     OverrideDataType(input, dataType),
                                     OverrideDataType(output, dataType),
                                     cLayer->GetParameters(),
                                     reason);
            break;
        }
        default:
        {
            // Unknown layer type: asserts in debug builds, reports unsupported in release.
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            reason.value() = "Unrecognised layer type";
            result = false;
            break;
        }
    }
    return result;
}
590
David Beckdcb751f2018-10-03 11:42:42 +0100591bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
592 boost::optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100593 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000594{
David Beckdcb751f2018-10-03 11:42:42 +0100595 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +0100596 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000597}
598
surmeh013537c2c2018-05-18 16:31:43 +0100599}