blob: a70097eb820209cdb22345ec8e05891e0e0b40b9 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "WorkloadFactory.hpp"
David Beckb4540be2018-09-24 13:18:27 +01006#include <backends/reference/RefWorkloadFactory.hpp>
David Beck0dbe0ee2018-09-24 15:59:27 +01007#include <backends/neon/NeonWorkloadFactory.hpp>
David Beckac42efd2018-09-26 17:41:13 +01008#include <backends/cl/ClWorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beckb4540be2018-09-24 13:18:27 +010010#include <armnn/Types.hpp>
11#include <armnn/LayerSupport.hpp>
12#include <Layer.hpp>
13#include <LayersFwd.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014#include "CpuTensorHandle.hpp"
15
16#include <boost/cast.hpp>
17#include <cstring>
18#include <boost/iterator/transform_iterator.hpp>
19
20namespace armnn
21{
22
telsoa01c577f2c2018-08-31 09:22:23 +010023namespace
24{
25 const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
26 {
27 if (type == boost::none)
28 {
29 return info;
30 }
31
32 return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
33 }
34
35 boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
36 {
37 if (weightsType == boost::none)
38 {
39 return weightsType;
40 }
41
42 switch(weightsType.get())
43 {
44 case DataType::Float16:
45 case DataType::Float32:
46 return weightsType;
47 case DataType::QuantisedAsymm8:
48 return DataType::Signed32;
49 default:
50 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
51 }
52 return boost::none;
53 }
54}
55
56bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boost::optional<DataType> dataType,
telsoa014fcda012018-03-09 14:13:49 +000057 std::string& outReasonIfUnsupported)
58{
59 constexpr size_t reasonCapacity = 1024;
60 char reason[reasonCapacity];
61 bool result;
62 switch(layer.GetType())
63 {
64 case LayerType::Activation:
65 {
66 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
67 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010068 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
69 result = IsActivationSupported(compute,
70 OverrideDataType(input, dataType),
71 OverrideDataType(output, dataType),
72 cLayer->GetParameters(),
73 reason,
74 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +000075 break;
76 }
77 case LayerType::Addition:
78 {
79 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
80 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
81 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010082 result = IsAdditionSupported(compute,
83 OverrideDataType(input0, dataType),
84 OverrideDataType(input1, dataType),
85 OverrideDataType(output, dataType),
86 reason,
87 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +000088 break;
89 }
90 case LayerType::BatchNormalization:
91 {
92 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
93 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010094 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
95 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
96 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
97 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
98 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
99 result = IsBatchNormalizationSupported(compute,
100 OverrideDataType(input, dataType),
101 OverrideDataType(output, dataType),
102 OverrideDataType(mean, dataType),
103 OverrideDataType(var, dataType),
104 OverrideDataType(beta, dataType),
105 OverrideDataType(gamma, dataType),
106 cLayer->GetParameters(),
107 reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Constant:
111 {
112 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 result = IsConstantSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
114 break;
115 }
116 case LayerType::ConvertFp16ToFp32:
117 {
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 result = IsConvertFp16ToFp32Supported(compute, input, output, reason, reasonCapacity);
121 break;
122 }
123 case LayerType::ConvertFp32ToFp16:
124 {
125 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
126 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
127 result = IsConvertFp32ToFp16Supported(compute, input, output, reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000128 break;
129 }
130 case LayerType::Convolution2d:
131 {
132 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100133
134 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
135 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100136 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100137 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
138
arovir01a6824102018-08-28 17:40:45 +0100139 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100140
arovir01a6824102018-08-28 17:40:45 +0100141 // Construct optional biases object based on the value of m_BiasEnabled
142 boost::optional<TensorInfo> biases(boost::none);
surmeh013537c2c2018-05-18 16:31:43 +0100143 if (descriptor.m_BiasEnabled)
144 {
arovir01a6824102018-08-28 17:40:45 +0100145 biases = boost::make_optional(
146 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100147 }
148
149 result = IsConvolution2dSupported(compute,
150 input,
151 output,
152 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100153 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100154 biases,
surmeh013537c2c2018-05-18 16:31:43 +0100155 reason,
156 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000157 break;
158 }
159 case LayerType::MemCopy:
160 {
telsoa01c577f2c2018-08-31 09:22:23 +0100161 // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
162 // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
telsoa014fcda012018-03-09 14:13:49 +0000163 result = compute == Compute::CpuRef || compute == Compute::Undefined
164 || compute == Compute::CpuAcc || compute == Compute::GpuAcc;
165 strcpy(reason, "Unsupported backend type");
166 break;
167 }
168 case LayerType::DepthwiseConvolution2d:
169 {
170 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100171 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
172 dataType);
173 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
174 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
175
telsoa01c577f2c2018-08-31 09:22:23 +0100176 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100177
178 // Construct optional biases object based on the value of m_BiasEnabled
179 boost::optional<TensorInfo> biases(boost::none);
telsoa01c577f2c2018-08-31 09:22:23 +0100180 if (descriptor.m_BiasEnabled)
181 {
arovir01a6824102018-08-28 17:40:45 +0100182 biases = boost::make_optional(
183 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)));
telsoa01c577f2c2018-08-31 09:22:23 +0100184 }
telsoa01c577f2c2018-08-31 09:22:23 +0100185
186 result = IsDepthwiseConvolutionSupported(compute,
187 input,
188 output,
189 descriptor,
190 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100191 biases,
telsoa01c577f2c2018-08-31 09:22:23 +0100192 reason,
193 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000194 break;
195 }
196 case LayerType::FakeQuantization:
197 {
198 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
199 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100200 result = IsFakeQuantizationSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(),
201 reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000202 break;
203 }
204 case LayerType::Floor:
205 {
206 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
207 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100208 result = IsFloorSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
209 reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000210 break;
211 }
212 case LayerType::FullyConnected:
213 {
214 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
215 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100216 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
217 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
218
219 TensorInfo biasInfo;
220 const TensorInfo * biasInfoPtr = nullptr;
221 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
222 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
223 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
224
225 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
226 if (descriptor.m_BiasEnabled)
227 {
228 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
229 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
230 biasInfoPtr = &biasInfo;
231 }
232 else
233 {
234 // If biases are not enabled pass a dummy tensorinfo for the validation
235 switch(input.GetDataType())
236 {
237 case DataType::Float16:
238 {
239 biasInfoPtr = &dummyFloat16Bias;
240 break;
241 }
242 case DataType::Float32:
243 {
244 biasInfoPtr = &dummyFloat32Bias;
245 break;
246 }
247 case DataType::QuantisedAsymm8:
248 {
249 biasInfoPtr = &dummyQA8Bias;
250 break;
251 }
252 default:
253 {
254 BOOST_ASSERT_MSG(false, "Unexpected bias type");
255 }
256 }
257 }
258
259 result = IsFullyConnectedSupported(compute,
260 OverrideDataType(input, dataType),
261 OverrideDataType(output, dataType),
262 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
263 *biasInfoPtr,
264 descriptor,
265 reason,
266 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000267 break;
268 }
269 case LayerType::Input:
270 {
271 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100272 result = IsInputSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000273 break;
274 }
275 case LayerType::L2Normalization:
276 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100277 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
278 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
279
telsoa014fcda012018-03-09 14:13:49 +0000280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100282
283 result = IsL2NormalizationSupported(compute,
284 OverrideDataType(input, dataType),
285 OverrideDataType(output, dataType),
286 descriptor,
287 reason,
288 reasonCapacity);
telsoa01c577f2c2018-08-31 09:22:23 +0100289 break;
290 }
291 case LayerType::Lstm:
292 {
293 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
294 const LstmDescriptor& descriptor = cLayer->GetParameters();
295
296 // All inputs.
297 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
298 dataType);
299 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
300 dataType);
301 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
302 dataType);
303 // All outputs
304 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
305 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
306 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
307 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
308
309 // Basic parameters
310 const TensorInfo& inputToForgetWeights
311 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
312 const TensorInfo& inputToCellWeights
313 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
314 const TensorInfo& inputToOutputWeights
315 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
316 const TensorInfo& recurrentToForgetWeights
317 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
318 const TensorInfo& recurrentToCellWeights
319 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
320 const TensorInfo& recurrentToOutputWeights
321 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
322 const TensorInfo& forgetGateBias
323 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
324 const TensorInfo& cellBias
325 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
326 const TensorInfo& outputGateBias
327 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
328
329 // Optional parameters
330 const TensorInfo* inputToInputWeights = nullptr;
331 const TensorInfo* recurrentToInputWeights = nullptr;
332 const TensorInfo* cellToInputWeights = nullptr;
333 const TensorInfo* inputGateBias = nullptr;
334 const TensorInfo* projectionWeights = nullptr;
335 const TensorInfo* projectionBias = nullptr;
336 const TensorInfo* cellToForgetWeights = nullptr;
337 const TensorInfo* cellToOutputWeights = nullptr;
338
339 TensorInfo optInputToInputWeights;
340 TensorInfo optRecurrentToInputWeights;
341 TensorInfo optCellToInputWeights;
342 TensorInfo optInputGateBias;
343 TensorInfo optProjectionWeights;
344 TensorInfo optProjectionBias;
345 TensorInfo optCellToForgetWeights;
346 TensorInfo optCellToOutputWeights;
347
348 if(!descriptor.m_CifgEnabled)
349 {
350 optInputToInputWeights =
351 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
352 inputToInputWeights = &optInputToInputWeights;
353
354 optRecurrentToInputWeights =
355 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
356 recurrentToInputWeights = &optRecurrentToInputWeights;
357 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
358 {
359 optCellToInputWeights =
360 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
361 cellToInputWeights = &optCellToInputWeights;
362 }
363 optInputGateBias =
364 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
365 inputGateBias = &optInputGateBias;
366 }
367
368 if(descriptor.m_ProjectionEnabled)
369 {
370 optProjectionWeights =
371 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
372 projectionWeights = &optProjectionWeights;
373 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
374 {
375 optProjectionBias =
376 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
377 projectionBias = &optProjectionBias;
378 }
379 }
380
381 if(descriptor.m_PeepholeEnabled)
382 {
383 optCellToForgetWeights =
384 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
385 cellToForgetWeights = &optCellToForgetWeights;
386 optCellToOutputWeights =
387 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
388 cellToOutputWeights = &optCellToOutputWeights;
389 }
390
391 result = IsLstmSupported(compute,
392 input,
393 outputStateIn,
394 cellStateIn,
395 scratchBuffer,
396 outputStateOut,
397 cellStateOut,
398 output,
399 descriptor,
400 inputToForgetWeights,
401 inputToCellWeights,
402 inputToOutputWeights,
403 recurrentToForgetWeights,
404 recurrentToCellWeights,
405 recurrentToOutputWeights,
406 forgetGateBias,
407 cellBias,
408 outputGateBias,
409 inputToInputWeights,
410 recurrentToInputWeights,
411 cellToInputWeights,
412 inputGateBias,
413 projectionWeights,
414 projectionBias,
415 cellToForgetWeights,
416 cellToOutputWeights,
417 reason,
418 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000419 break;
420 }
421 case LayerType::Merger:
422 {
423 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
424
telsoa01c577f2c2018-08-31 09:22:23 +0100425 // Get vector of all inputs.
426 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000427 {
telsoa01c577f2c2018-08-31 09:22:23 +0100428 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000429 };
telsoa01c577f2c2018-08-31 09:22:23 +0100430 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
431 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
432 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000433
telsoa01c577f2c2018-08-31 09:22:23 +0100434 auto getTensorInfoPtr = [](const TensorInfo& info)
435 {
436 return &info;
437 };
438 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
439 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
440 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000441
telsoa01c577f2c2018-08-31 09:22:23 +0100442 result = IsMergerSupported(compute, inputPtrs, cLayer->GetParameters(), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000443 break;
444 }
445 case LayerType::Multiplication:
446 {
447 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
448 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100449 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
450 result = IsMultiplicationSupported(compute,
451 OverrideDataType(input0, dataType),
452 OverrideDataType(input1, dataType),
453 OverrideDataType(output, dataType),
454 reason,
455 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000456 break;
457 }
458 case LayerType::Normalization:
459 {
460 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
461 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
462 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100463 result = IsNormalizationSupported(compute, OverrideDataType(input, dataType),
464 OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
465 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000466 break;
467 }
468 case LayerType::Output:
469 {
470 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100471 result = IsOutputSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000472 break;
473 }
474 case LayerType::Permute:
475 {
476 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
477 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
478 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100479 result = IsPermuteSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
480 cLayer->GetParameters(), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000481 break;
482 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100483 case LayerType::Pad:
484 {
485 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
486 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
487 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
488 result = IsPadSupported(compute,
489 OverrideDataType(input, dataType),
490 OverrideDataType(output, dataType),
491 cLayer->GetParameters(),
492 reason,
493 reasonCapacity);
494 break;
495 }
telsoa014fcda012018-03-09 14:13:49 +0000496 case LayerType::Pooling2d:
497 {
498 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
499 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
500 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100501 result = IsPooling2dSupported(compute, OverrideDataType(input, dataType),
502 OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
503 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000504 break;
505 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100506 case LayerType::Division:
507 {
508 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
509 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
510 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
511 result = IsDivisionSupported(compute,
512 OverrideDataType(input0, dataType),
513 OverrideDataType(input1, dataType),
514 OverrideDataType(output, dataType),
515 reason,
516 reasonCapacity);
517 break;
518 }
telsoa014fcda012018-03-09 14:13:49 +0000519 case LayerType::Reshape:
520 {
521 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100522 result = IsReshapeSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000523 break;
524 }
525 case LayerType::ResizeBilinear:
526 {
527 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100528 result = IsResizeBilinearSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000529 break;
530 }
531 case LayerType::Softmax:
532 {
533 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
534 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100535 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
536 result = IsSoftmaxSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
537 cLayer->GetParameters(), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000538 break;
539 }
540 case LayerType::Splitter:
541 {
542 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
543 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100544 result = IsSplitterSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(), reason,
545 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000546 break;
547 }
David Beckc2044fe2018-09-05 15:00:38 +0100548 case LayerType::Subtraction:
549 {
550 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
551 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
552 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
553 result = IsSubtractionSupported(compute,
554 OverrideDataType(input0, dataType),
555 OverrideDataType(input1, dataType),
556 OverrideDataType(output, dataType),
557 reason,
558 reasonCapacity);
559 break;
560 }
narpra0132b90462018-09-13 11:07:48 +0100561 case LayerType::Mean:
562 {
563 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
564 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
565 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
566 result = IsMeanSupported(compute,
567 OverrideDataType(input, dataType),
568 OverrideDataType(output, dataType),
569 cLayer->GetParameters(),
570 reason,
571 reasonCapacity);
572 break;
573 }
telsoa014fcda012018-03-09 14:13:49 +0000574 default:
575 {
576 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
577 strcpy(reason, "Unrecognised layer type");
578 result = false;
579 break;
580 }
581 }
582 outReasonIfUnsupported = reason;
583 return result;
584}
585
telsoa01c577f2c2018-08-31 09:22:23 +0100586bool IWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
587 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000588{
589 return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);
590}
591
surmeh013537c2c2018-05-18 16:31:43 +0100592}