blob: 34bfd7cead68d60e41c491a67c2af1694c63e90b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000012#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010013#include <armnn/utility/PolymorphicDowncast.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000016#include <armnn/backends/IBackendInternal.hpp>
17#include <backendsCommon/CpuTensorHandle.hpp>
18#include <backendsCommon/WorkloadFactory.hpp>
19
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
telsoa014fcda012018-03-09 14:13:49 +000022#include <boost/iterator/transform_iterator.hpp>
23
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000025#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000026
telsoa014fcda012018-03-09 14:13:49 +000027namespace armnn
28{
29
telsoa01c577f2c2018-08-31 09:22:23 +010030namespace
31{
telsoa01c577f2c2018-08-31 09:22:23 +010032
David Beck29c75de2018-10-23 13:35:58 +010033const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
34{
35 if (!type)
36 {
37 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010038 }
39
David Beck29c75de2018-10-23 13:35:58 +010040 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
David Beck33f0ae02018-10-18 15:13:56 +010045bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010046 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010047 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010048 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010052 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010053
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
71 case LayerType::Activation:
72 {
Jan Eilersbb446e52020-04-02 13:56:54 +010073 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000074 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010075 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010076 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010077 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010080 reason);
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
83 case LayerType::Addition:
84 {
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010088 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010089 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010092 reason);
telsoa014fcda012018-03-09 14:13:49 +000093 break;
94 }
Nikhil Rajee391d52019-09-05 17:50:44 +010095 case LayerType::ArgMinMax:
96 {
Jan Eilersbb446e52020-04-02 13:56:54 +010097 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +010098 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
99
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
102 result = layerSupportObject->IsArgMinMaxSupported(
103 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000104 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 descriptor,
106 reason);
107 break;
108 }
telsoa014fcda012018-03-09 14:13:49 +0000109 case LayerType::BatchNormalization:
110 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100111 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100118 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100126 reason);
telsoa014fcda012018-03-09 14:13:49 +0000127 break;
128 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000129 case LayerType::BatchToSpaceNd:
130 {
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100133 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000134
135 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
138 reason);
139 break;
140 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100141 case LayerType::Comparison:
142 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100143 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100144
145 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
148
149 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
150 OverrideDataType(input1, dataType),
151 OverrideDataType(output, DataType::Boolean),
152 cLayer->GetParameters(),
153 reason);
154 break;
155 }
telsoa014fcda012018-03-09 14:13:49 +0000156 case LayerType::Constant:
157 {
158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100159 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100160 break;
161 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000162 case LayerType::ConvertBf16ToFp32:
163 {
164 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
165 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
166 result = layerSupportObject->IsConvertBf16ToFp32Supported(input, output, reason);
167 break;
168 }
telsoa01c577f2c2018-08-31 09:22:23 +0100169 case LayerType::ConvertFp16ToFp32:
170 {
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100173 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100174 break;
175 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000176 case LayerType::ConvertFp32ToBf16:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180 result = layerSupportObject->IsConvertFp32ToBf16Supported(input, output, reason);
181 break;
182 }
telsoa01c577f2c2018-08-31 09:22:23 +0100183 case LayerType::ConvertFp32ToFp16:
184 {
185 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100187 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000188 break;
189 }
190 case LayerType::Convolution2d:
191 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100192 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100193
194 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
195 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100196 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100197 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100198
arovir01a6824102018-08-28 17:40:45 +0100199 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100200
arovir01a6824102018-08-28 17:40:45 +0100201 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100202 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100203 if (descriptor.m_BiasEnabled)
204 {
David Beck5eec11d2018-10-04 15:43:17 +0100205 biases =
206 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100207 }
208
David Beck33f0ae02018-10-18 15:13:56 +0100209 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100210 input,
211 output,
212 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100213 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100214 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100215 reason);
telsoa014fcda012018-03-09 14:13:49 +0000216 break;
217 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000218 case LayerType::Debug:
219 {
220 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
222
223 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
224 OverrideDataType(output, dataType),
225 reason);
226 break;
227 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100228 case LayerType::DepthToSpace:
229 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100230 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100231
232 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
234
235 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
236 OverrideDataType(output, dataType),
237 cLayer->GetParameters(),
238 reason);
239 break;
240 }
telsoa014fcda012018-03-09 14:13:49 +0000241 case LayerType::DepthwiseConvolution2d:
242 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100243 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100244 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
245 dataType);
246 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100247 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100248
telsoa01c577f2c2018-08-31 09:22:23 +0100249 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100250
251 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100252 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100253 if (descriptor.m_BiasEnabled)
254 {
David Beck5eec11d2018-10-04 15:43:17 +0100255 biases =
256 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100257 }
telsoa01c577f2c2018-08-31 09:22:23 +0100258
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100260 input,
261 output,
262 descriptor,
263 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100264 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100265 reason);
telsoa014fcda012018-03-09 14:13:49 +0000266 break;
267 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000268 case LayerType::Dequantize:
269 {
270 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
271 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
272
Aron Virginas-Tar87972be2019-11-13 15:16:28 +0000273 result = layerSupportObject->IsDequantizeSupported(input,
274 OverrideDataType(output, dataType),
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000275 reason);
276 break;
277 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000278 case LayerType::DetectionPostProcess:
279 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100280 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000281 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
282 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
283 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
284
285 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
286 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
287 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
288 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
289
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000290 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000291 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
292 scores,
293 anchors,
294 detectionBoxes,
295 detectionClasses,
296 detectionScores,
297 numDetections,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000298 descriptor,
299 reason);
300 break;
301 }
josh minor4a3c6102020-01-06 16:40:46 -0600302 case LayerType::ElementwiseUnary:
303 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100304 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600305
306 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
307 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
308
309 result = layerSupportObject->IsElementwiseUnarySupported(OverrideDataType(input, dataType),
310 OverrideDataType(output, dataType),
311 cLayer->GetParameters(),
312 reason);
313 break;
314 }
telsoa014fcda012018-03-09 14:13:49 +0000315 case LayerType::FakeQuantization:
316 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100317 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000318 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100319 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
320 cLayer->GetParameters(),
321 reason);
telsoa014fcda012018-03-09 14:13:49 +0000322 break;
323 }
324 case LayerType::Floor:
325 {
326 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
327 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100328 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
329 OverrideDataType(output, dataType),
330 reason);
telsoa014fcda012018-03-09 14:13:49 +0000331 break;
332 }
333 case LayerType::FullyConnected:
334 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100335 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000336 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100337 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100338 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100339
340 TensorInfo biasInfo;
341 const TensorInfo * biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000342 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100343 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
344 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
345 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
346
347 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
348 if (descriptor.m_BiasEnabled)
349 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100350 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100351 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
352 biasInfoPtr = &biasInfo;
353 }
354 else
355 {
356 // If biases are not enabled pass a dummy tensorinfo for the validation
357 switch(input.GetDataType())
358 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000359 case DataType::BFloat16:
360 {
361 biasInfoPtr = &dummyBFloat16Bias;
362 break;
363 }
telsoa01c577f2c2018-08-31 09:22:23 +0100364 case DataType::Float16:
365 {
366 biasInfoPtr = &dummyFloat16Bias;
367 break;
368 }
369 case DataType::Float32:
370 {
371 biasInfoPtr = &dummyFloat32Bias;
372 break;
373 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000374 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000375 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000376 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000377 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100378 {
379 biasInfoPtr = &dummyQA8Bias;
380 break;
381 }
382 default:
383 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100384 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100385 }
386 }
387 }
388
David Beck33f0ae02018-10-18 15:13:56 +0100389 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100390 OverrideDataType(input, dataType),
391 OverrideDataType(output, dataType),
392 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
393 *biasInfoPtr,
394 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100395 reason);
telsoa014fcda012018-03-09 14:13:49 +0000396 break;
397 }
narpra01b89b05f2019-01-16 09:53:09 +0000398 case LayerType::Gather:
399 {
400 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
401 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
402 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
403 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100404 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000405 OverrideDataType(output, dataType),
406 reason);
407 break;
408 }
telsoa014fcda012018-03-09 14:13:49 +0000409 case LayerType::Input:
410 {
411 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100412 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000413 break;
414 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100415 case LayerType::InstanceNormalization:
416 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100417 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100418 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
419
420 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
421 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
422
423 result = layerSupportObject->IsInstanceNormalizationSupported(
424 OverrideDataType(input, dataType),
425 OverrideDataType(output, dataType),
426 descriptor,
427 reason);
428 break;
429 }
telsoa014fcda012018-03-09 14:13:49 +0000430 case LayerType::L2Normalization:
431 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100432 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100433 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
434
telsoa014fcda012018-03-09 14:13:49 +0000435 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100436 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100437
David Beck33f0ae02018-10-18 15:13:56 +0100438 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100439 OverrideDataType(input, dataType),
440 OverrideDataType(output, dataType),
441 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100442 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100443 break;
444 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100445 case LayerType::LogSoftmax:
446 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100447 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100448
449 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
450 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
451
452 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
453 OverrideDataType(output, dataType),
454 cLayer->GetParameters(),
455 reason);
456 break;
457 }
telsoa01c577f2c2018-08-31 09:22:23 +0100458 case LayerType::Lstm:
459 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100460 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100461 const LstmDescriptor& descriptor = cLayer->GetParameters();
462
463 // All inputs.
464 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
465 dataType);
466 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
467 dataType);
468 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
469 dataType);
470 // All outputs
471 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
472 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
473 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
474 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
475
476 // Basic parameters
477 const TensorInfo& inputToForgetWeights
478 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
479 const TensorInfo& inputToCellWeights
480 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
481 const TensorInfo& inputToOutputWeights
482 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
483 const TensorInfo& recurrentToForgetWeights
484 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
485 const TensorInfo& recurrentToCellWeights
486 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
487 const TensorInfo& recurrentToOutputWeights
488 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
489 const TensorInfo& forgetGateBias
490 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
491 const TensorInfo& cellBias
492 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
493 const TensorInfo& outputGateBias
494 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
495
Jan Eilersd01a83c2019-07-03 18:20:40 +0100496 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100497
Jan Eilersd01a83c2019-07-03 18:20:40 +0100498 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
499 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
500 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
501 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
502 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
503 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
504 paramsInfo.m_ForgetGateBias = &forgetGateBias;
505 paramsInfo.m_CellBias = &cellBias;
506 paramsInfo.m_OutputGateBias = &outputGateBias;
507
508
509 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100510 TensorInfo optInputToInputWeights;
511 TensorInfo optRecurrentToInputWeights;
512 TensorInfo optCellToInputWeights;
513 TensorInfo optInputGateBias;
514 TensorInfo optProjectionWeights;
515 TensorInfo optProjectionBias;
516 TensorInfo optCellToForgetWeights;
517 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100518 TensorInfo optInputLayerNormWeights;
519 TensorInfo optForgetLayerNormWeights;
520 TensorInfo optCellLayerNormWeights;
521 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100522
523 if(!descriptor.m_CifgEnabled)
524 {
525 optInputToInputWeights =
526 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100527 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100528
529 optRecurrentToInputWeights =
530 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100531 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100532 optInputGateBias =
533 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100534 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100535 }
536
537 if(descriptor.m_ProjectionEnabled)
538 {
539 optProjectionWeights =
540 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100541 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100542 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
543 {
544 optProjectionBias =
545 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100546 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100547 }
548 }
549
550 if(descriptor.m_PeepholeEnabled)
551 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100552 if(!descriptor.m_CifgEnabled)
553 {
554 optCellToInputWeights =
555 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
556 dataType);
557 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
558 }
telsoa01c577f2c2018-08-31 09:22:23 +0100559 optCellToForgetWeights =
560 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100561 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100562 optCellToOutputWeights =
563 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100564 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100565 }
566
Jan Eilers38e05bd2019-06-26 13:10:09 +0100567 if(descriptor.m_LayerNormEnabled)
568 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100569 if (!descriptor.m_CifgEnabled)
570 {
571 optInputLayerNormWeights = OverrideDataType(
572 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
573 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
574 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100575
576 optForgetLayerNormWeights = OverrideDataType(
577 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100578 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100579
580 optCellLayerNormWeights = OverrideDataType(
581 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100582 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100583
584 optOutputLayerNormWeights = OverrideDataType(
585 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100586 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100587 }
588
David Beck33f0ae02018-10-18 15:13:56 +0100589 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100590 input,
591 outputStateIn,
592 cellStateIn,
593 scratchBuffer,
594 outputStateOut,
595 cellStateOut,
596 output,
597 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100598 paramsInfo,
599 reason);
telsoa014fcda012018-03-09 14:13:49 +0000600 break;
601 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000602 case LayerType::Maximum:
603 {
604 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
605 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
606 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
607
608 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
609 OverrideDataType(input1, dataType),
610 OverrideDataType(output, dataType),
611 reason);
612 break;
613 }
narpra01b89b05f2019-01-16 09:53:09 +0000614 case LayerType::MemCopy:
615 {
616 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
617 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000618
narpra01b89b05f2019-01-16 09:53:09 +0000619 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
620 OverrideDataType(output, dataType),
621 reason);
622 break;
623 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100624 case LayerType::MemImport:
625 {
626 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
627 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
628
629 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
630 OverrideDataType(output, dataType),
631 reason);
632 break;
633 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100634 case LayerType::Merge:
635 {
636 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
637 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
638 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
639
640 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
641 OverrideDataType(input1, dataType),
642 OverrideDataType(output, dataType),
643 reason);
644 break;
645 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100646 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000647 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100648 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000649
telsoa01c577f2c2018-08-31 09:22:23 +0100650 // Get vector of all inputs.
651 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000652 {
telsoa01c577f2c2018-08-31 09:22:23 +0100653 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000654 };
telsoa01c577f2c2018-08-31 09:22:23 +0100655 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
656 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
657 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000658
telsoa01c577f2c2018-08-31 09:22:23 +0100659 auto getTensorInfoPtr = [](const TensorInfo& info)
660 {
661 return &info;
662 };
663 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
664 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
665 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000666
Nikhil Raj8599a412018-11-19 14:51:07 +0000667 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
668
Jim Flynne242f2d2019-05-22 14:24:13 +0100669 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
670
671
telsoa014fcda012018-03-09 14:13:49 +0000672 break;
673 }
674 case LayerType::Multiplication:
675 {
676 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
677 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100678 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100679 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100680 OverrideDataType(input0, dataType),
681 OverrideDataType(input1, dataType),
682 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100683 reason);
telsoa014fcda012018-03-09 14:13:49 +0000684 break;
685 }
686 case LayerType::Normalization:
687 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100688 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000689 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
690 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100691 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
692 OverrideDataType(output, dataType),
693 cLayer->GetParameters(),
694 reason);
telsoa014fcda012018-03-09 14:13:49 +0000695 break;
696 }
697 case LayerType::Output:
698 {
699 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100700 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000701 break;
702 }
703 case LayerType::Permute:
704 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100705 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000706 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
707 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100708 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
709 OverrideDataType(output, dataType),
710 cLayer->GetParameters(),
711 reason);
telsoa014fcda012018-03-09 14:13:49 +0000712 break;
713 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100714 case LayerType::Pad:
715 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100716 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100717 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
718 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100719 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100720 OverrideDataType(input, dataType),
721 OverrideDataType(output, dataType),
722 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100723 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100724 break;
725 }
telsoa014fcda012018-03-09 14:13:49 +0000726 case LayerType::Pooling2d:
727 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100728 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000729 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
730 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100731 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
732 OverrideDataType(output, dataType),
733 cLayer->GetParameters(),
734 reason);
telsoa014fcda012018-03-09 14:13:49 +0000735 break;
736 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000737 case LayerType::PreCompiled:
738 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100739 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000740 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
741 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
742 cLayer->GetParameters(),
743 reason);
744 break;
745 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000746 case LayerType::Quantize:
747 {
748 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
749 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
750 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
751 break;
752 }
James Conroy586a9aa2020-03-20 08:49:33 +0000753 case LayerType::QLstm:
754 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100755 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000756 const QLstmDescriptor& descriptor = cLayer->GetParameters();
757
758 // Inputs
759 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
760 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
761 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
762
763 // Outputs
764 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
765 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
766 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
767
768 // Lstm parameters
769 LstmInputParamsInfo paramsInfo;
770
771 // Basic parameters
772 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
773 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
774 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
775
776 paramsInfo.m_RecurrentToForgetWeights =
777 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
778 paramsInfo.m_RecurrentToCellWeights =
779 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
780 paramsInfo.m_RecurrentToOutputWeights =
781 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
782
783 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
784 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
785 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
786
787 if(!descriptor.m_CifgEnabled)
788 {
789 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
790 paramsInfo.m_RecurrentToInputWeights =
791 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
792 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
793 }
794
795 if(descriptor.m_ProjectionEnabled)
796 {
797 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100798
799 // Projection bias is optional even if projection is enabled
800 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
801 {
802 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
803 }
James Conroy586a9aa2020-03-20 08:49:33 +0000804 }
805
806 if(descriptor.m_PeepholeEnabled)
807 {
808 if (!descriptor.m_CifgEnabled)
809 {
810 paramsInfo.m_CellToInputWeights =
811 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
812 }
813
814 paramsInfo.m_CellToForgetWeights =
815 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
816 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
817 }
818
819 if(descriptor.m_LayerNormEnabled)
820 {
821 if (!descriptor.m_CifgEnabled)
822 {
823 paramsInfo.m_InputLayerNormWeights =
824 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
825 }
826
827 paramsInfo.m_ForgetLayerNormWeights =
828 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
829 paramsInfo.m_CellLayerNormWeights =
830 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
831 paramsInfo.m_OutputLayerNormWeights =
832 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
833 }
834
835 result = layerSupportObject->IsQLstmSupported(input,
836 previousOutputIn,
837 previousCellStateIn,
838 outputStateOut,
839 cellStateOut,
840 output,
841 descriptor,
842 paramsInfo,
843 reason);
844 break;
845 }
James Conroyee18dc82019-07-17 11:27:46 +0100846 case LayerType::QuantizedLstm:
847 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100848 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100849
850 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100851 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
852 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
853 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100854
855 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100856 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
857 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100858
859 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100860 QuantizedLstmInputParamsInfo paramsInfo;
861
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100862 paramsInfo.m_InputToInputWeights =
863 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
864 paramsInfo.m_InputToForgetWeights =
865 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
866 paramsInfo.m_InputToCellWeights =
867 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
868 paramsInfo.m_InputToOutputWeights =
869 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100870
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100871 paramsInfo.m_RecurrentToInputWeights =
872 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
873 paramsInfo.m_RecurrentToForgetWeights =
874 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
875 paramsInfo.m_RecurrentToCellWeights =
876 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
877 paramsInfo.m_RecurrentToOutputWeights =
878 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100879
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100880 paramsInfo.m_InputGateBias =
881 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
882 paramsInfo.m_ForgetGateBias =
883 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
884 paramsInfo.m_CellBias =
885 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
886 paramsInfo.m_OutputGateBias =
887 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100888
889 result = layerSupportObject->IsQuantizedLstmSupported(input,
890 previousCellStateIn,
891 previousOutputIn,
892 cellStateOut,
893 output,
894 paramsInfo,
895 reason);
896 break;
897 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100898 case LayerType::Division:
899 {
900 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
901 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
902 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100903 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100904 OverrideDataType(input0, dataType),
905 OverrideDataType(input1, dataType),
906 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100907 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100908 break;
909 }
telsoa014fcda012018-03-09 14:13:49 +0000910 case LayerType::Reshape:
911 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100912 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000913 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000914 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000915 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
Kevin Maya023c402019-12-12 17:28:05 +0000916 OverrideDataType(output, dataType),
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000917 cLayer->GetParameters(),
918 reason);
telsoa014fcda012018-03-09 14:13:49 +0000919 break;
920 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100921 case LayerType::Resize:
922 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100923 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100924 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100925 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
926 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
927 OverrideDataType(output, dataType),
928 cLayer->GetParameters(),
929 reason);
930 break;
931 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100932 case LayerType::Slice:
933 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100934 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100935
936 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
937 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
938
939 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
940 OverrideDataType(output, dataType),
941 cLayer->GetParameters(),
942 reason);
943 break;
944 }
telsoa014fcda012018-03-09 14:13:49 +0000945 case LayerType::Softmax:
946 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100947 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000948 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100949 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100950 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
951 OverrideDataType(output, dataType),
952 cLayer->GetParameters(),
953 reason);
telsoa014fcda012018-03-09 14:13:49 +0000954 break;
955 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000956 case LayerType::SpaceToBatchNd:
957 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100958 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000959 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
960 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
961 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
962 OverrideDataType(output, dataType),
963 cLayer->GetParameters(),
964 reason);
965 break;
966 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100967 case LayerType::SpaceToDepth:
968 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100969 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100970
971 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
972 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
973
974 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
975 OverrideDataType(output, dataType),
976 cLayer->GetParameters(),
977 reason);
978 break;
979 }
telsoa014fcda012018-03-09 14:13:49 +0000980 case LayerType::Splitter:
981 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100982 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000983 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100984
985 // Get vector of all outputs.
986 auto getTensorInfo = [&dataType](const OutputSlot& slot)
987 {
988 return OverrideDataType(slot.GetTensorInfo(), dataType);
989 };
990 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
991 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
992 std::vector<TensorInfo> outputs(beginI, endI);
993
994 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
995
David Beck33f0ae02018-10-18 15:13:56 +0100996 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100997 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100998 cLayer->GetParameters(),
999 reason);
telsoa014fcda012018-03-09 14:13:49 +00001000 break;
1001 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001002 case LayerType::Stack:
1003 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001004 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001005
1006 // Get vector of all inputs.
1007 auto getTensorInfo = [&dataType](const InputSlot& slot)
1008 {
1009 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1010 };
1011 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
1012 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
1013 std::vector<TensorInfo> inputs(beginI, endI);
1014
1015 auto getTensorInfoPtr = [](const TensorInfo& info)
1016 {
1017 return &info;
1018 };
1019 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
1020 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
1021 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1022
1023 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1024
1025 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1026
1027 break;
1028 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001029 case LayerType::StandIn:
1030 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001031 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001032
1033 // Get vector of all inputs.
1034 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1035 {
1036 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1037 };
1038 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1039 {
1040 return OverrideDataType(slot.GetTensorInfo(), dataType);
1041 };
1042 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1043 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
1044 std::vector<TensorInfo> inputs(beginI, endI);
1045
1046 auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1047 auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
1048 std::vector<TensorInfo> outputs(beginO, endO);
1049
1050
1051 auto getTensorInfoPtr = [](const TensorInfo& info)
1052 {
1053 return &info;
1054 };
1055 auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
1056 auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
1057 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1058
1059 auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
1060 auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
1061 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1062
1063
1064 result = layerSupportObject->IsStandInSupported(inputPtrs,
1065 outputPtrs,
1066 cLayer->GetParameters(),
1067 reason);
1068 break;
1069 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001070 case LayerType::StridedSlice:
1071 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001072 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001073 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1074 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1075 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
1076 OverrideDataType(output, dataType),
1077 cLayer->GetParameters(),
1078 reason);
1079 break;
1080 }
David Beckc2044fe2018-09-05 15:00:38 +01001081 case LayerType::Subtraction:
1082 {
1083 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1084 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1085 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001086 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001087 OverrideDataType(input0, dataType),
1088 OverrideDataType(input1, dataType),
1089 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001090 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001091 break;
1092 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001093 case LayerType::Switch:
1094 {
1095 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1096 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1097 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1098 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1099 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
1100 OverrideDataType(input1, dataType),
1101 OverrideDataType(output0, dataType),
1102 OverrideDataType(output1, dataType),
1103 reason);
1104 break;
1105 }
narpra0132b90462018-09-13 11:07:48 +01001106 case LayerType::Mean:
1107 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001108 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001109 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1110 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001111 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001112 OverrideDataType(input, dataType),
1113 OverrideDataType(output, dataType),
1114 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001115 reason);
narpra0132b90462018-09-13 11:07:48 +01001116 break;
1117 }
kevmay0190539692018-11-29 08:40:19 +00001118 case LayerType::Minimum:
1119 {
1120 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1121 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1122 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1123 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1124 OverrideDataType(input1, dataType),
1125 OverrideDataType(output, dataType),
1126 reason);
1127 break;
1128 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001129 case LayerType::Prelu:
1130 {
1131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1132 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1133 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1134 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1135 OverrideDataType(alpha, dataType),
1136 OverrideDataType(output, dataType),
1137 reason);
1138 break;
1139 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001140 case LayerType::Transpose:
1141 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001142 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1145 result = layerSupportObject->IsTransposeSupported(OverrideDataType(input, dataType),
1146 OverrideDataType(output, dataType),
1147 cLayer->GetParameters(),
1148 reason);
1149 break;
1150 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001151 case LayerType::TransposeConvolution2d:
1152 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001153 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001154
1155 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1156 dataType);
1157 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1158
1159 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1160
1161 Optional<TensorInfo> biases;
1162 if (descriptor.m_BiasEnabled)
1163 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001164 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001165 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1166 GetBiasTypeFromWeightsType(dataType));
1167 }
1168
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001169 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001170 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1171
1172 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1173 output,
1174 descriptor,
1175 weights,
1176 biases,
1177 reason);
1178
1179 break;
1180 }
telsoa014fcda012018-03-09 14:13:49 +00001181 default:
1182 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001183 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001184 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001185 result = false;
1186 break;
1187 }
1188 }
telsoa014fcda012018-03-09 14:13:49 +00001189 return result;
1190}
1191
David Beckdcb751f2018-10-03 11:42:42 +01001192bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001193 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001194 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001195{
Jan Eilersbb446e52020-04-02 13:56:54 +01001196 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +01001197 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +00001198}
1199
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001200// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001201std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1202 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001203{
1204 return std::unique_ptr<IWorkload>();
1205}
1206
Derek Lamberti901ea112019-12-10 22:07:09 +00001207std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1208 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001209{
1210 return std::unique_ptr<IWorkload>();
1211}
1212
Derek Lamberti901ea112019-12-10 22:07:09 +00001213std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1214 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001215{
1216 return std::unique_ptr<IWorkload>();
1217}
1218
Derek Lamberti901ea112019-12-10 22:07:09 +00001219std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1220 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001221{
1222 return std::unique_ptr<IWorkload>();
1223}
1224
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001225std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001226 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001227{
1228 return std::unique_ptr<IWorkload>();
1229}
1230
Derek Lamberti901ea112019-12-10 22:07:09 +00001231std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1232 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001233{
1234 return std::unique_ptr<IWorkload>();
1235}
1236
Derek Lamberti901ea112019-12-10 22:07:09 +00001237std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1238 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001239{
1240 return std::unique_ptr<IWorkload>();
1241}
1242
Derek Lamberti901ea112019-12-10 22:07:09 +00001243std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1244 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001245{
1246 return std::unique_ptr<IWorkload>();
1247}
1248
Derek Lamberti901ea112019-12-10 22:07:09 +00001249std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1250 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001251{
1252 return std::unique_ptr<IWorkload>();
1253}
1254
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001255std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1256 const WorkloadInfo& /*info*/) const
1257{
1258 return std::unique_ptr<IWorkload>();
1259}
1260
Derek Lamberti901ea112019-12-10 22:07:09 +00001261std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1262 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001263{
1264 return std::unique_ptr<IWorkload>();
1265}
1266
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001267std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1268 const WorkloadInfo& /*info*/) const
1269{
1270 return std::unique_ptr<IWorkload>();
1271}
1272
Derek Lamberti901ea112019-12-10 22:07:09 +00001273std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1274 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001275{
1276 return std::unique_ptr<IWorkload>();
1277}
1278
Derek Lamberti901ea112019-12-10 22:07:09 +00001279std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1280 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001281{
1282 return std::unique_ptr<IWorkload>();
1283}
1284
Derek Lamberti901ea112019-12-10 22:07:09 +00001285std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1286 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001287{
1288 return std::unique_ptr<IWorkload>();
1289}
1290
Derek Lamberti901ea112019-12-10 22:07:09 +00001291std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1292 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001293{
1294 return std::unique_ptr<IWorkload>();
1295}
1296
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001297std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001298 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001299{
1300 return std::unique_ptr<IWorkload>();
1301}
1302
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001303std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001304 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001305{
1306 return std::unique_ptr<IWorkload>();
1307}
1308
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001309std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001310 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001311{
1312 return std::unique_ptr<IWorkload>();
1313}
1314
Derek Lamberti901ea112019-12-10 22:07:09 +00001315std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1316 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001317{
1318 return std::unique_ptr<IWorkload>();
1319}
1320
josh minor4a3c6102020-01-06 16:40:46 -06001321std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1322 const WorkloadInfo& /*info*/) const
1323{
1324 return std::unique_ptr<IWorkload>();
1325}
1326
Derek Lamberti901ea112019-12-10 22:07:09 +00001327std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1328 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001329{
1330 return std::unique_ptr<IWorkload>();
1331}
1332
Derek Lamberti901ea112019-12-10 22:07:09 +00001333std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1334 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001335{
1336 return std::unique_ptr<IWorkload>();
1337}
1338
Derek Lamberti901ea112019-12-10 22:07:09 +00001339std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1340 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001341{
1342 return std::unique_ptr<IWorkload>();
1343}
1344
Derek Lamberti901ea112019-12-10 22:07:09 +00001345std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1346 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001347{
1348 return std::unique_ptr<IWorkload>();
1349}
1350
Derek Lamberti901ea112019-12-10 22:07:09 +00001351std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1352 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001353{
1354 return std::unique_ptr<IWorkload>();
1355}
1356
Derek Lamberti901ea112019-12-10 22:07:09 +00001357std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1358 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001359{
1360 return std::unique_ptr<IWorkload>();
1361}
1362
Kevin Mayce5045a2019-10-02 14:07:47 +01001363std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001364 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1365 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001366{
1367 return std::unique_ptr<IWorkload>();
1368}
1369
Derek Lamberti901ea112019-12-10 22:07:09 +00001370std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1371 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001372{
1373 return std::unique_ptr<IWorkload>();
1374}
1375
Derek Lamberti901ea112019-12-10 22:07:09 +00001376std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1377 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001378{
1379 return std::unique_ptr<IWorkload>();
1380}
1381
Derek Lamberti901ea112019-12-10 22:07:09 +00001382std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1383 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001384{
1385 return std::unique_ptr<IWorkload>();
1386}
1387
Derek Lamberti901ea112019-12-10 22:07:09 +00001388std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1389 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001390{
1391 return std::unique_ptr<IWorkload>();
1392}
1393
Derek Lamberti901ea112019-12-10 22:07:09 +00001394std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1395 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001396{
1397 return std::unique_ptr<IWorkload>();
1398}
1399
Derek Lamberti901ea112019-12-10 22:07:09 +00001400std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1401 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001402{
1403 return std::unique_ptr<IWorkload>();
1404}
1405
Derek Lamberti901ea112019-12-10 22:07:09 +00001406std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1407 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001408{
1409 return std::unique_ptr<IWorkload>();
1410}
1411
Derek Lamberti901ea112019-12-10 22:07:09 +00001412std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1413 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001414{
1415 return std::unique_ptr<IWorkload>();
1416}
1417
Derek Lamberti901ea112019-12-10 22:07:09 +00001418std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1419 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001420{
1421 return std::unique_ptr<IWorkload>();
1422}
1423
Derek Lamberti901ea112019-12-10 22:07:09 +00001424std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1425 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001426{
1427 return std::unique_ptr<IWorkload>();
1428}
1429
Derek Lamberti901ea112019-12-10 22:07:09 +00001430std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1431 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001432{
1433 return std::unique_ptr<IWorkload>();
1434}
1435
Derek Lamberti901ea112019-12-10 22:07:09 +00001436std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1437 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001438{
1439 return std::unique_ptr<IWorkload>();
1440}
1441
Derek Lamberti901ea112019-12-10 22:07:09 +00001442std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1443 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001444{
1445 return std::unique_ptr<IWorkload>();
1446}
1447
Derek Lamberti901ea112019-12-10 22:07:09 +00001448std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1449 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001450{
1451 return std::unique_ptr<IWorkload>();
1452}
1453
Derek Lamberti901ea112019-12-10 22:07:09 +00001454std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001455 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001456{
1457 return std::unique_ptr<IWorkload>();
1458}
1459
Derek Lamberti901ea112019-12-10 22:07:09 +00001460std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1461 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001462{
1463 return std::unique_ptr<IWorkload>();
1464}
1465
Derek Lamberti901ea112019-12-10 22:07:09 +00001466std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1467 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001468{
1469 return std::unique_ptr<IWorkload>();
1470}
1471
Derek Lamberti901ea112019-12-10 22:07:09 +00001472std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1473 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001474{
1475 return std::unique_ptr<IWorkload>();
1476}
1477
Derek Lamberti901ea112019-12-10 22:07:09 +00001478std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1479 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001480{
1481 return std::unique_ptr<IWorkload>();
1482}
1483
James Conroy586a9aa2020-03-20 08:49:33 +00001484std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1485 const WorkloadInfo& /*info*/) const
1486{
1487 return std::unique_ptr<IWorkload>();
1488}
1489
Derek Lamberti901ea112019-12-10 22:07:09 +00001490std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1491 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001492{
1493 return std::unique_ptr<IWorkload>();
1494}
1495
Derek Lamberti901ea112019-12-10 22:07:09 +00001496std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1497 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001498{
1499 return std::unique_ptr<IWorkload>();
1500}
1501
Derek Lamberti901ea112019-12-10 22:07:09 +00001502std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1503 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001504{
1505 return std::unique_ptr<IWorkload>();
1506}
1507
Derek Lamberti901ea112019-12-10 22:07:09 +00001508std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1509 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001510{
1511 return std::unique_ptr<IWorkload>();
1512}
1513
Derek Lamberti901ea112019-12-10 22:07:09 +00001514std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1515 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001516{
1517 return std::unique_ptr<IWorkload>();
1518}
1519
Derek Lamberti901ea112019-12-10 22:07:09 +00001520std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1521 const WorkloadInfo& /*info*/) const
1522{
1523 return std::unique_ptr<IWorkload>();
1524}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001525
Derek Lamberti901ea112019-12-10 22:07:09 +00001526std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1527 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001528{
1529 return std::unique_ptr<IWorkload>();
1530}
1531
Derek Lamberti901ea112019-12-10 22:07:09 +00001532std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1533 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001534{
1535 return std::unique_ptr<IWorkload>();
1536}
1537
Derek Lamberti901ea112019-12-10 22:07:09 +00001538std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1539 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001540{
1541 return std::unique_ptr<IWorkload>();
1542}
1543
Derek Lamberti901ea112019-12-10 22:07:09 +00001544std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1545 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001546{
1547 return std::unique_ptr<IWorkload>();
1548}
1549
Derek Lamberti901ea112019-12-10 22:07:09 +00001550std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1551 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001552{
1553 return std::unique_ptr<IWorkload>();
1554}
1555
Derek Lamberti901ea112019-12-10 22:07:09 +00001556std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1557 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001558{
1559 return std::unique_ptr<IWorkload>();
1560}
1561
Derek Lamberti901ea112019-12-10 22:07:09 +00001562std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1563 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001564{
1565 return std::unique_ptr<IWorkload>();
1566}
1567
Derek Lamberti901ea112019-12-10 22:07:09 +00001568std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1569 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001570{
1571 return std::unique_ptr<IWorkload>();
1572}
1573
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001574std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1575 const WorkloadInfo& /*info*/) const
1576{
1577 return std::unique_ptr<IWorkload>();
1578}
1579
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001580std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001581 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1582 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001583{
1584 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001585}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001586
} // namespace armnn