blob: 1b3f29421acb8eebf244ab42944f3b2ce73e249b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#include "WorkloadFactory.hpp"
6#include "RefWorkloadFactory.hpp"
7#include "NeonWorkloadFactory.hpp"
8#include "ClWorkloadFactory.hpp"
9
10#include "armnn/Types.hpp"
11#include "armnn/LayerSupport.hpp"
12#include "Layer.hpp"
surmeh013537c2c2018-05-18 16:31:43 +010013#include "LayersFwd.hpp"
telsoa014fcda012018-03-09 14:13:49 +000014#include "CpuTensorHandle.hpp"
15
16#include <boost/cast.hpp>
17#include <cstring>
18#include <boost/iterator/transform_iterator.hpp>
19
20namespace armnn
21{
22
telsoa01c577f2c2018-08-31 09:22:23 +010023namespace
24{
25 const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
26 {
27 if (type == boost::none)
28 {
29 return info;
30 }
31
32 return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
33 }
34
35 boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
36 {
37 if (weightsType == boost::none)
38 {
39 return weightsType;
40 }
41
42 switch(weightsType.get())
43 {
44 case DataType::Float16:
45 case DataType::Float32:
46 return weightsType;
47 case DataType::QuantisedAsymm8:
48 return DataType::Signed32;
49 default:
50 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
51 }
52 return boost::none;
53 }
54}
55
56bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boost::optional<DataType> dataType,
telsoa014fcda012018-03-09 14:13:49 +000057 std::string& outReasonIfUnsupported)
58{
59 constexpr size_t reasonCapacity = 1024;
60 char reason[reasonCapacity];
61 bool result;
62 switch(layer.GetType())
63 {
64 case LayerType::Activation:
65 {
66 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
67 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010068 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
69 result = IsActivationSupported(compute,
70 OverrideDataType(input, dataType),
71 OverrideDataType(output, dataType),
72 cLayer->GetParameters(),
73 reason,
74 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +000075 break;
76 }
77 case LayerType::Addition:
78 {
79 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
80 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
81 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010082 result = IsAdditionSupported(compute,
83 OverrideDataType(input0, dataType),
84 OverrideDataType(input1, dataType),
85 OverrideDataType(output, dataType),
86 reason,
87 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +000088 break;
89 }
90 case LayerType::BatchNormalization:
91 {
92 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
93 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010094 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
95 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
96 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
97 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
98 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
99 result = IsBatchNormalizationSupported(compute,
100 OverrideDataType(input, dataType),
101 OverrideDataType(output, dataType),
102 OverrideDataType(mean, dataType),
103 OverrideDataType(var, dataType),
104 OverrideDataType(beta, dataType),
105 OverrideDataType(gamma, dataType),
106 cLayer->GetParameters(),
107 reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Constant:
111 {
112 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 result = IsConstantSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
114 break;
115 }
116 case LayerType::ConvertFp16ToFp32:
117 {
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 result = IsConvertFp16ToFp32Supported(compute, input, output, reason, reasonCapacity);
121 break;
122 }
123 case LayerType::ConvertFp32ToFp16:
124 {
125 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
126 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
127 result = IsConvertFp32ToFp16Supported(compute, input, output, reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000128 break;
129 }
130 case LayerType::Convolution2d:
131 {
132 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100133 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
134 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100135 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
136
telsoa01c577f2c2018-08-31 09:22:23 +0100137 TensorInfo biasInfo;
138 const TensorInfo * biasInfoPtr = nullptr;
139 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
surmeh013537c2c2018-05-18 16:31:43 +0100140 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
141 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
142
143 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
144
145 if (descriptor.m_BiasEnabled)
146 {
147 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100148 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
149 biasInfoPtr = &biasInfo;
surmeh013537c2c2018-05-18 16:31:43 +0100150 }
151 else
152 {
telsoa01c577f2c2018-08-31 09:22:23 +0100153 // If biases are not enabled pass a dummy tensorinfo for the validation.
surmeh013537c2c2018-05-18 16:31:43 +0100154 switch(input.GetDataType())
155 {
telsoa01c577f2c2018-08-31 09:22:23 +0100156 case DataType::Float16:
157 {
158 biasInfoPtr = &dummyFloat16Bias;
159 break;
160 }
surmeh013537c2c2018-05-18 16:31:43 +0100161 case DataType::Float32:
162 {
telsoa01c577f2c2018-08-31 09:22:23 +0100163 biasInfoPtr = &dummyFloat32Bias;
surmeh013537c2c2018-05-18 16:31:43 +0100164 break;
165 }
166 case DataType::QuantisedAsymm8:
167 {
telsoa01c577f2c2018-08-31 09:22:23 +0100168 biasInfoPtr = &dummyQA8Bias;
surmeh013537c2c2018-05-18 16:31:43 +0100169 break;
170 }
171 default:
172 {
173 BOOST_ASSERT_MSG(false, "Unexpected input type");
174 }
175 }
176 }
177
178 result = IsConvolution2dSupported(compute,
179 input,
180 output,
181 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100182 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
183 *biasInfoPtr,
surmeh013537c2c2018-05-18 16:31:43 +0100184 reason,
185 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000186 break;
187 }
188 case LayerType::MemCopy:
189 {
telsoa01c577f2c2018-08-31 09:22:23 +0100190 // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
191 // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
telsoa014fcda012018-03-09 14:13:49 +0000192 result = compute == Compute::CpuRef || compute == Compute::Undefined
193 || compute == Compute::CpuAcc || compute == Compute::GpuAcc;
194 strcpy(reason, "Unsupported backend type");
195 break;
196 }
197 case LayerType::DepthwiseConvolution2d:
198 {
199 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100200 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
201 dataType);
202 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
203 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
204
205 TensorInfo biasInfo;
206 const TensorInfo * biasInfoPtr = nullptr;
207 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
208 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
209 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
210
211 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
212 if (descriptor.m_BiasEnabled)
213 {
214 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
215 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
216 biasInfoPtr = &biasInfo;
217 }
218 else
219 {
220 // If biases are not enabled pass a dummy tensorinfo for the validation
221 switch(input.GetDataType())
222 {
223 case DataType::Float16:
224 {
225 biasInfoPtr = &dummyFloat16Bias;
226 break;
227 }
228 case DataType::Float32:
229 {
230 biasInfoPtr = &dummyFloat32Bias;
231 break;
232 }
233 case DataType::QuantisedAsymm8:
234 {
235 biasInfoPtr = &dummyQA8Bias;
236 break;
237 }
238 default:
239 {
240 BOOST_ASSERT_MSG(false, "Unexpected bias type");
241 }
242 }
243 }
244
245
246 result = IsDepthwiseConvolutionSupported(compute,
247 input,
248 output,
249 descriptor,
250 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
251 *biasInfoPtr,
252 reason,
253 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000254 break;
255 }
256 case LayerType::FakeQuantization:
257 {
258 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
259 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100260 result = IsFakeQuantizationSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(),
261 reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000262 break;
263 }
264 case LayerType::Floor:
265 {
266 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
267 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100268 result = IsFloorSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
269 reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000270 break;
271 }
272 case LayerType::FullyConnected:
273 {
274 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
275 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100276 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
277 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
278
279 TensorInfo biasInfo;
280 const TensorInfo * biasInfoPtr = nullptr;
281 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
282 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
283 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
284
285 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
286 if (descriptor.m_BiasEnabled)
287 {
288 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
289 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
290 biasInfoPtr = &biasInfo;
291 }
292 else
293 {
294 // If biases are not enabled pass a dummy tensorinfo for the validation
295 switch(input.GetDataType())
296 {
297 case DataType::Float16:
298 {
299 biasInfoPtr = &dummyFloat16Bias;
300 break;
301 }
302 case DataType::Float32:
303 {
304 biasInfoPtr = &dummyFloat32Bias;
305 break;
306 }
307 case DataType::QuantisedAsymm8:
308 {
309 biasInfoPtr = &dummyQA8Bias;
310 break;
311 }
312 default:
313 {
314 BOOST_ASSERT_MSG(false, "Unexpected bias type");
315 }
316 }
317 }
318
319 result = IsFullyConnectedSupported(compute,
320 OverrideDataType(input, dataType),
321 OverrideDataType(output, dataType),
322 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
323 *biasInfoPtr,
324 descriptor,
325 reason,
326 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000327 break;
328 }
329 case LayerType::Input:
330 {
331 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100332 result = IsInputSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000333 break;
334 }
335 case LayerType::L2Normalization:
336 {
337 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100338 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
339 result = IsL2NormalizationSupported(compute, OverrideDataType(input, dataType),
340 OverrideDataType(output, dataType), reason, reasonCapacity);
341 break;
342 }
343 case LayerType::Lstm:
344 {
345 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
346 const LstmDescriptor& descriptor = cLayer->GetParameters();
347
348 // All inputs.
349 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
350 dataType);
351 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
352 dataType);
353 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
354 dataType);
355 // All outputs
356 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
357 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
358 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
359 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
360
361 // Basic parameters
362 const TensorInfo& inputToForgetWeights
363 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
364 const TensorInfo& inputToCellWeights
365 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
366 const TensorInfo& inputToOutputWeights
367 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
368 const TensorInfo& recurrentToForgetWeights
369 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
370 const TensorInfo& recurrentToCellWeights
371 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
372 const TensorInfo& recurrentToOutputWeights
373 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
374 const TensorInfo& forgetGateBias
375 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
376 const TensorInfo& cellBias
377 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
378 const TensorInfo& outputGateBias
379 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
380
381 // Optional parameters
382 const TensorInfo* inputToInputWeights = nullptr;
383 const TensorInfo* recurrentToInputWeights = nullptr;
384 const TensorInfo* cellToInputWeights = nullptr;
385 const TensorInfo* inputGateBias = nullptr;
386 const TensorInfo* projectionWeights = nullptr;
387 const TensorInfo* projectionBias = nullptr;
388 const TensorInfo* cellToForgetWeights = nullptr;
389 const TensorInfo* cellToOutputWeights = nullptr;
390
391 TensorInfo optInputToInputWeights;
392 TensorInfo optRecurrentToInputWeights;
393 TensorInfo optCellToInputWeights;
394 TensorInfo optInputGateBias;
395 TensorInfo optProjectionWeights;
396 TensorInfo optProjectionBias;
397 TensorInfo optCellToForgetWeights;
398 TensorInfo optCellToOutputWeights;
399
400 if(!descriptor.m_CifgEnabled)
401 {
402 optInputToInputWeights =
403 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
404 inputToInputWeights = &optInputToInputWeights;
405
406 optRecurrentToInputWeights =
407 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
408 recurrentToInputWeights = &optRecurrentToInputWeights;
409 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
410 {
411 optCellToInputWeights =
412 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
413 cellToInputWeights = &optCellToInputWeights;
414 }
415 optInputGateBias =
416 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
417 inputGateBias = &optInputGateBias;
418 }
419
420 if(descriptor.m_ProjectionEnabled)
421 {
422 optProjectionWeights =
423 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
424 projectionWeights = &optProjectionWeights;
425 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
426 {
427 optProjectionBias =
428 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
429 projectionBias = &optProjectionBias;
430 }
431 }
432
433 if(descriptor.m_PeepholeEnabled)
434 {
435 optCellToForgetWeights =
436 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
437 cellToForgetWeights = &optCellToForgetWeights;
438 optCellToOutputWeights =
439 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
440 cellToOutputWeights = &optCellToOutputWeights;
441 }
442
443 result = IsLstmSupported(compute,
444 input,
445 outputStateIn,
446 cellStateIn,
447 scratchBuffer,
448 outputStateOut,
449 cellStateOut,
450 output,
451 descriptor,
452 inputToForgetWeights,
453 inputToCellWeights,
454 inputToOutputWeights,
455 recurrentToForgetWeights,
456 recurrentToCellWeights,
457 recurrentToOutputWeights,
458 forgetGateBias,
459 cellBias,
460 outputGateBias,
461 inputToInputWeights,
462 recurrentToInputWeights,
463 cellToInputWeights,
464 inputGateBias,
465 projectionWeights,
466 projectionBias,
467 cellToForgetWeights,
468 cellToOutputWeights,
469 reason,
470 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000471 break;
472 }
473 case LayerType::Merger:
474 {
475 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
476
telsoa01c577f2c2018-08-31 09:22:23 +0100477 // Get vector of all inputs.
478 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000479 {
telsoa01c577f2c2018-08-31 09:22:23 +0100480 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000481 };
telsoa01c577f2c2018-08-31 09:22:23 +0100482 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
483 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
484 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000485
telsoa01c577f2c2018-08-31 09:22:23 +0100486 auto getTensorInfoPtr = [](const TensorInfo& info)
487 {
488 return &info;
489 };
490 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
491 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
492 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000493
telsoa01c577f2c2018-08-31 09:22:23 +0100494 result = IsMergerSupported(compute, inputPtrs, cLayer->GetParameters(), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000495 break;
496 }
497 case LayerType::Multiplication:
498 {
499 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
500 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100501 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
502 result = IsMultiplicationSupported(compute,
503 OverrideDataType(input0, dataType),
504 OverrideDataType(input1, dataType),
505 OverrideDataType(output, dataType),
506 reason,
507 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000508 break;
509 }
510 case LayerType::Normalization:
511 {
512 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
513 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
514 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100515 result = IsNormalizationSupported(compute, OverrideDataType(input, dataType),
516 OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
517 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000518 break;
519 }
520 case LayerType::Output:
521 {
522 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100523 result = IsOutputSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000524 break;
525 }
526 case LayerType::Permute:
527 {
528 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
529 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
530 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100531 result = IsPermuteSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
532 cLayer->GetParameters(), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000533 break;
534 }
535 case LayerType::Pooling2d:
536 {
537 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
538 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
539 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100540 result = IsPooling2dSupported(compute, OverrideDataType(input, dataType),
541 OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
542 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000543 break;
544 }
545 case LayerType::Reshape:
546 {
547 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100548 result = IsReshapeSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000549 break;
550 }
551 case LayerType::ResizeBilinear:
552 {
553 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100554 result = IsResizeBilinearSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000555 break;
556 }
557 case LayerType::Softmax:
558 {
559 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
560 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100561 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
562 result = IsSoftmaxSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
563 cLayer->GetParameters(), reason, reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000564 break;
565 }
566 case LayerType::Splitter:
567 {
568 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
569 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100570 result = IsSplitterSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(), reason,
571 reasonCapacity);
telsoa014fcda012018-03-09 14:13:49 +0000572 break;
573 }
574 default:
575 {
576 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
577 strcpy(reason, "Unrecognised layer type");
578 result = false;
579 break;
580 }
581 }
582 outReasonIfUnsupported = reason;
583 return result;
584}
585
telsoa01c577f2c2018-08-31 09:22:23 +0100586bool IWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
587 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000588{
589 return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);
590}
591
surmeh013537c2c2018-05-18 16:31:43 +0100592}