blob: 1c18551679686409dfb44d706fd531fcd0041788 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017#include <backendsCommon/WorkloadFactory.hpp>
James Conroy1f58f032021-04-27 17:13:27 +010018#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
David Beck111b5d92018-11-12 14:59:37 +000022#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
telsoa01c577f2c2018-08-31 09:22:23 +010027namespace
28{
Finn Williams3e54d032020-10-22 16:53:35 +010029using LayerList = std::list<Layer*>;
30using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010031
David Beck29c75de2018-10-23 13:35:58 +010032const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
33{
34 if (!type)
35 {
36 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010037 }
38
David Beck29c75de2018-10-23 13:35:58 +010039 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010040}
41
David Beck29c75de2018-10-23 13:35:58 +010042} // anonymous namespace
43
Sadik Armagan045f6be2020-09-10 13:37:32 +010044bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
45 const IConnectableLayer& connectableLayer,
46 Optional<DataType> dataType,
47 std::string& outReasonIfUnsupported,
48 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010052 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010053
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000067 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
71 case LayerType::Activation:
72 {
Jan Eilersbb446e52020-04-02 13:56:54 +010073 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000074 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010075 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000076 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010077 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010080 reason);
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
83 case LayerType::Addition:
84 {
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000088 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010089 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010092 reason);
telsoa014fcda012018-03-09 14:13:49 +000093 break;
94 }
Nikhil Rajee391d52019-09-05 17:50:44 +010095 case LayerType::ArgMinMax:
96 {
Jan Eilersbb446e52020-04-02 13:56:54 +010097 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +010098 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
99
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000102 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100103 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000104 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 descriptor,
106 reason);
107 break;
108 }
telsoa014fcda012018-03-09 14:13:49 +0000109 case LayerType::BatchNormalization:
110 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100111 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000118 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100126 reason);
telsoa014fcda012018-03-09 14:13:49 +0000127 break;
128 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000129 case LayerType::BatchToSpaceNd:
130 {
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100133 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000134
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000135 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
138 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000139 break;
140 }
mathad01b392e982021-04-07 12:07:30 +0100141 case LayerType::Cast:
142 {
143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
145
146 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
147 OverrideDataType(output, dataType),
148 reason);
149 break;
150 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100151 case LayerType::Comparison:
152 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100153 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100154
155 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
156 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
157 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
158
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000159 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
160 OverrideDataType(input1, dataType),
161 OverrideDataType(output, DataType::Boolean),
162 cLayer->GetParameters(),
163 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100164 break;
165 }
telsoa014fcda012018-03-09 14:13:49 +0000166 case LayerType::Constant:
167 {
168 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000169 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100170 break;
171 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000172 case LayerType::ConvertBf16ToFp32:
173 {
174 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000176 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000177 break;
178 }
telsoa01c577f2c2018-08-31 09:22:23 +0100179 case LayerType::ConvertFp16ToFp32:
180 {
181 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
182 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000183 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100184 break;
185 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000186 case LayerType::ConvertFp32ToBf16:
187 {
188 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
189 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000190 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000191 break;
192 }
telsoa01c577f2c2018-08-31 09:22:23 +0100193 case LayerType::ConvertFp32ToFp16:
194 {
195 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
196 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000197 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000198 break;
199 }
200 case LayerType::Convolution2d:
201 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100202 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100203
204 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
205 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100206 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100207 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100208
arovir01a6824102018-08-28 17:40:45 +0100209 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100210
arovir01a6824102018-08-28 17:40:45 +0100211 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100212 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100213 if (descriptor.m_BiasEnabled)
214 {
David Beck5eec11d2018-10-04 15:43:17 +0100215 biases =
216 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100217 }
218
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000219 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100220 input,
221 output,
222 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100223 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100224 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100225 reason);
telsoa014fcda012018-03-09 14:13:49 +0000226 break;
227 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000228 case LayerType::Debug:
229 {
230 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
231 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
232
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000233 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000234 OverrideDataType(output, dataType),
235 reason);
236 break;
237 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100238 case LayerType::DepthToSpace:
239 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100240 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100241
242 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
243 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
244
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000245 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100246 OverrideDataType(output, dataType),
247 cLayer->GetParameters(),
248 reason);
249 break;
250 }
telsoa014fcda012018-03-09 14:13:49 +0000251 case LayerType::DepthwiseConvolution2d:
252 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100253 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100254 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
255 dataType);
256 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100257 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100258
telsoa01c577f2c2018-08-31 09:22:23 +0100259 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100260
261 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100262 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100263 if (descriptor.m_BiasEnabled)
264 {
David Beck5eec11d2018-10-04 15:43:17 +0100265 biases =
266 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100267 }
telsoa01c577f2c2018-08-31 09:22:23 +0100268
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000269 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100270 input,
271 output,
272 descriptor,
273 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100274 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100275 reason);
telsoa014fcda012018-03-09 14:13:49 +0000276 break;
277 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000278 case LayerType::Dequantize:
279 {
280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
282
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000283 result = layerSupportObject.IsDequantizeSupported(input,
284 OverrideDataType(output, dataType),
285 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000286 break;
287 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000288 case LayerType::DetectionPostProcess:
289 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100290 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000291 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
292 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
293 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
294
295 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
296 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
297 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
298 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
299
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000300 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000301 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
302 scores,
303 anchors,
304 detectionBoxes,
305 detectionClasses,
306 detectionScores,
307 numDetections,
308 descriptor,
309 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000310 break;
311 }
josh minor4a3c6102020-01-06 16:40:46 -0600312 case LayerType::ElementwiseUnary:
313 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100314 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600315
316 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
317 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
318
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000319 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
320 OverrideDataType(output, dataType),
321 cLayer->GetParameters(),
322 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600323 break;
324 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100325 case LayerType::Fill:
326 {
327 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
328 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
329 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
330 const FillDescriptor& descriptor = cLayer->GetParameters();
331
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000332 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100333 OverrideDataType(input, dataType),
334 OverrideDataType(output, dataType),
335 descriptor,
336 reason);
337 break;
338 }
telsoa014fcda012018-03-09 14:13:49 +0000339 case LayerType::FakeQuantization:
340 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100341 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000342 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000343 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
344 cLayer->GetParameters(),
345 reason);
telsoa014fcda012018-03-09 14:13:49 +0000346 break;
347 }
348 case LayerType::Floor:
349 {
350 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
351 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000352 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
353 OverrideDataType(output, dataType),
354 reason);
telsoa014fcda012018-03-09 14:13:49 +0000355 break;
356 }
357 case LayerType::FullyConnected:
358 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100359 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000360 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100361 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000362
363 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
364 TensorInfo weightsInfo;
365 const TensorInfo* weightsInfoPtr = nullptr;
366
367 if (descriptor.m_ConstantWeights)
368 {
369 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
370 weightsInfo = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
371 }
372 else
373 {
374 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
375
376 }
377 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100378
379 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000380 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000381 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100382 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
383 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
384 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
385
telsoa01c577f2c2018-08-31 09:22:23 +0100386 if (descriptor.m_BiasEnabled)
387 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000388 if(descriptor.m_ConstantWeights)
389 {
390 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
391 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
392 biasInfoPtr = &biasInfo;
393 }
394 else
395 {
396 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
397 biasInfoPtr = &biasInfo;
398 }
telsoa01c577f2c2018-08-31 09:22:23 +0100399 }
400 else
401 {
402 // If biases are not enabled pass a dummy tensorinfo for the validation
403 switch(input.GetDataType())
404 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000405 case DataType::BFloat16:
406 {
407 biasInfoPtr = &dummyBFloat16Bias;
408 break;
409 }
telsoa01c577f2c2018-08-31 09:22:23 +0100410 case DataType::Float16:
411 {
412 biasInfoPtr = &dummyFloat16Bias;
413 break;
414 }
415 case DataType::Float32:
416 {
417 biasInfoPtr = &dummyFloat32Bias;
418 break;
419 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000420 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000421 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000422 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000423 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100424 {
425 biasInfoPtr = &dummyQA8Bias;
426 break;
427 }
428 default:
429 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100430 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100431 }
432 }
433 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000434 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100435 OverrideDataType(input, dataType),
436 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000437 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100438 *biasInfoPtr,
439 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100440 reason);
telsoa014fcda012018-03-09 14:13:49 +0000441 break;
442 }
narpra01b89b05f2019-01-16 09:53:09 +0000443 case LayerType::Gather:
444 {
445 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
446 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
447 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100448 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
449 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000450 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
451 input1,
452 OverrideDataType(output, dataType),
453 descriptor,
454 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000455 break;
456 }
telsoa014fcda012018-03-09 14:13:49 +0000457 case LayerType::Input:
458 {
459 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000460 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000461 break;
462 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100463 case LayerType::InstanceNormalization:
464 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100465 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100466 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
467
468 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
469 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
470
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000471 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100472 OverrideDataType(input, dataType),
473 OverrideDataType(output, dataType),
474 descriptor,
475 reason);
476 break;
477 }
telsoa014fcda012018-03-09 14:13:49 +0000478 case LayerType::L2Normalization:
479 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100480 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100481 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
482
telsoa014fcda012018-03-09 14:13:49 +0000483 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100484 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100485
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000486 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100487 OverrideDataType(input, dataType),
488 OverrideDataType(output, dataType),
489 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100490 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100491 break;
492 }
James Conroyaba90cd2020-11-06 16:28:18 +0000493 case LayerType::LogicalBinary:
494 {
495 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
496
497 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
498 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
499 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
500
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000501 result = layerSupportObject.IsLogicalBinarySupported(input0,
502 input1,
503 output,
504 cLayer->GetParameters(),
505 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000506 break;
507 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100508 case LayerType::LogSoftmax:
509 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100510 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100511
512 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
513 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
514
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000515 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
516 OverrideDataType(output, dataType),
517 cLayer->GetParameters(),
518 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100519 break;
520 }
telsoa01c577f2c2018-08-31 09:22:23 +0100521 case LayerType::Lstm:
522 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100523 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100524 const LstmDescriptor& descriptor = cLayer->GetParameters();
525
526 // All inputs.
527 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
528 dataType);
529 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
530 dataType);
531 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
532 dataType);
533 // All outputs
534 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
535 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
536 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
537 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
538
539 // Basic parameters
540 const TensorInfo& inputToForgetWeights
541 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
542 const TensorInfo& inputToCellWeights
543 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
544 const TensorInfo& inputToOutputWeights
545 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
546 const TensorInfo& recurrentToForgetWeights
547 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
548 const TensorInfo& recurrentToCellWeights
549 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
550 const TensorInfo& recurrentToOutputWeights
551 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
552 const TensorInfo& forgetGateBias
553 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
554 const TensorInfo& cellBias
555 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
556 const TensorInfo& outputGateBias
557 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
558
Jan Eilersd01a83c2019-07-03 18:20:40 +0100559 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100560
Jan Eilersd01a83c2019-07-03 18:20:40 +0100561 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
562 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
563 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
564 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
565 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
566 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
567 paramsInfo.m_ForgetGateBias = &forgetGateBias;
568 paramsInfo.m_CellBias = &cellBias;
569 paramsInfo.m_OutputGateBias = &outputGateBias;
570
571
572 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100573 TensorInfo optInputToInputWeights;
574 TensorInfo optRecurrentToInputWeights;
575 TensorInfo optCellToInputWeights;
576 TensorInfo optInputGateBias;
577 TensorInfo optProjectionWeights;
578 TensorInfo optProjectionBias;
579 TensorInfo optCellToForgetWeights;
580 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100581 TensorInfo optInputLayerNormWeights;
582 TensorInfo optForgetLayerNormWeights;
583 TensorInfo optCellLayerNormWeights;
584 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100585
586 if(!descriptor.m_CifgEnabled)
587 {
588 optInputToInputWeights =
589 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100590 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100591
592 optRecurrentToInputWeights =
593 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100594 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100595 optInputGateBias =
596 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100597 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100598 }
599
600 if(descriptor.m_ProjectionEnabled)
601 {
602 optProjectionWeights =
603 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100604 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100605 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
606 {
607 optProjectionBias =
608 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100609 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100610 }
611 }
612
613 if(descriptor.m_PeepholeEnabled)
614 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100615 if(!descriptor.m_CifgEnabled)
616 {
617 optCellToInputWeights =
618 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
619 dataType);
620 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
621 }
telsoa01c577f2c2018-08-31 09:22:23 +0100622 optCellToForgetWeights =
623 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100624 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100625 optCellToOutputWeights =
626 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100627 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100628 }
629
Jan Eilers38e05bd2019-06-26 13:10:09 +0100630 if(descriptor.m_LayerNormEnabled)
631 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100632 if (!descriptor.m_CifgEnabled)
633 {
634 optInputLayerNormWeights = OverrideDataType(
635 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
636 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
637 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100638
639 optForgetLayerNormWeights = OverrideDataType(
640 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100641 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100642
643 optCellLayerNormWeights = OverrideDataType(
644 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100645 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100646
647 optOutputLayerNormWeights = OverrideDataType(
648 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100649 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100650 }
651
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000652 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100653 input,
654 outputStateIn,
655 cellStateIn,
656 scratchBuffer,
657 outputStateOut,
658 cellStateOut,
659 output,
660 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100661 paramsInfo,
662 reason);
telsoa014fcda012018-03-09 14:13:49 +0000663 break;
664 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000665 case LayerType::Maximum:
666 {
667 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
668 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
669 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
670
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000671 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
672 OverrideDataType(input1, dataType),
673 OverrideDataType(output, dataType),
674 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000675 break;
676 }
narpra01b89b05f2019-01-16 09:53:09 +0000677 case LayerType::MemCopy:
678 {
679 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
680 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000681
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000682 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
683 OverrideDataType(output, dataType),
684 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000685 break;
686 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100687 case LayerType::MemImport:
688 {
689 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
690 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
691
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000692 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
693 OverrideDataType(output, dataType),
694 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100695 break;
696 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100697 case LayerType::Merge:
698 {
699 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
700 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
701 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
702
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000703 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
704 OverrideDataType(input1, dataType),
705 OverrideDataType(output, dataType),
706 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100707 break;
708 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100709 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000710 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100711 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000712
telsoa01c577f2c2018-08-31 09:22:23 +0100713 // Get vector of all inputs.
714 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000715 {
telsoa01c577f2c2018-08-31 09:22:23 +0100716 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000717 };
Finn Williams3e54d032020-10-22 16:53:35 +0100718
719 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
720 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100721 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000722
telsoa01c577f2c2018-08-31 09:22:23 +0100723 auto getTensorInfoPtr = [](const TensorInfo& info)
724 {
725 return &info;
726 };
Finn Williams3e54d032020-10-22 16:53:35 +0100727
728 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
729 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100730 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000731
Nikhil Raj8599a412018-11-19 14:51:07 +0000732 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
733
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000734 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100735
736
telsoa014fcda012018-03-09 14:13:49 +0000737 break;
738 }
739 case LayerType::Multiplication:
740 {
741 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
742 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100743 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000744 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100745 OverrideDataType(input0, dataType),
746 OverrideDataType(input1, dataType),
747 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100748 reason);
telsoa014fcda012018-03-09 14:13:49 +0000749 break;
750 }
751 case LayerType::Normalization:
752 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100753 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000754 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
755 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000756 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
757 OverrideDataType(output, dataType),
758 cLayer->GetParameters(),
759 reason);
telsoa014fcda012018-03-09 14:13:49 +0000760 break;
761 }
762 case LayerType::Output:
763 {
764 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000765 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000766 break;
767 }
768 case LayerType::Permute:
769 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100770 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000771 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
772 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000773 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
774 OverrideDataType(output, dataType),
775 cLayer->GetParameters(),
776 reason);
telsoa014fcda012018-03-09 14:13:49 +0000777 break;
778 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100779 case LayerType::Pad:
780 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100781 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100782 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
783 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000784 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100785 OverrideDataType(input, dataType),
786 OverrideDataType(output, dataType),
787 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100788 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100789 break;
790 }
telsoa014fcda012018-03-09 14:13:49 +0000791 case LayerType::Pooling2d:
792 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100793 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000794 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
795 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000796 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
797 OverrideDataType(output, dataType),
798 cLayer->GetParameters(),
799 reason);
telsoa014fcda012018-03-09 14:13:49 +0000800 break;
801 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000802 case LayerType::PreCompiled:
803 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100804 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000805 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000806 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
807 cLayer->GetParameters(),
808 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000809 break;
810 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000811 case LayerType::Quantize:
812 {
813 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
814 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000815 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000816 break;
817 }
James Conroy586a9aa2020-03-20 08:49:33 +0000818 case LayerType::QLstm:
819 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100820 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000821 const QLstmDescriptor& descriptor = cLayer->GetParameters();
822
823 // Inputs
824 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
825 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
826 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
827
828 // Outputs
829 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
830 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
831 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
832
833 // Lstm parameters
834 LstmInputParamsInfo paramsInfo;
835
836 // Basic parameters
Matthew Bentham6f24b1a2021-06-29 15:18:32 +0100837 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
838 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
839 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
James Conroy586a9aa2020-03-20 08:49:33 +0000840 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
841 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
842 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
843
844 paramsInfo.m_RecurrentToForgetWeights =
845 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
846 paramsInfo.m_RecurrentToCellWeights =
847 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
848 paramsInfo.m_RecurrentToOutputWeights =
849 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
850
851 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
852 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
853 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
854
855 if(!descriptor.m_CifgEnabled)
856 {
857 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
858 paramsInfo.m_RecurrentToInputWeights =
859 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
860 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
861 }
862
863 if(descriptor.m_ProjectionEnabled)
864 {
865 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100866
867 // Projection bias is optional even if projection is enabled
868 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
869 {
870 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
871 }
James Conroy586a9aa2020-03-20 08:49:33 +0000872 }
873
874 if(descriptor.m_PeepholeEnabled)
875 {
876 if (!descriptor.m_CifgEnabled)
877 {
878 paramsInfo.m_CellToInputWeights =
879 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
880 }
881
882 paramsInfo.m_CellToForgetWeights =
883 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
884 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
885 }
886
887 if(descriptor.m_LayerNormEnabled)
888 {
889 if (!descriptor.m_CifgEnabled)
890 {
891 paramsInfo.m_InputLayerNormWeights =
892 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
893 }
894
895 paramsInfo.m_ForgetLayerNormWeights =
896 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
897 paramsInfo.m_CellLayerNormWeights =
898 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
899 paramsInfo.m_OutputLayerNormWeights =
900 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
901 }
902
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000903 result = layerSupportObject.IsQLstmSupported(input,
904 previousOutputIn,
905 previousCellStateIn,
906 outputStateOut,
907 cellStateOut,
908 output,
909 descriptor,
910 paramsInfo,
911 reason);
James Conroy586a9aa2020-03-20 08:49:33 +0000912 break;
913 }
James Conroyee18dc82019-07-17 11:27:46 +0100914 case LayerType::QuantizedLstm:
915 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100916 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100917
918 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100919 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
920 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
921 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100922
923 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100924 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
925 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100926
927 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100928 QuantizedLstmInputParamsInfo paramsInfo;
929
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100930 paramsInfo.m_InputToInputWeights =
931 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
932 paramsInfo.m_InputToForgetWeights =
933 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
934 paramsInfo.m_InputToCellWeights =
935 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
936 paramsInfo.m_InputToOutputWeights =
937 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100938
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100939 paramsInfo.m_RecurrentToInputWeights =
940 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
941 paramsInfo.m_RecurrentToForgetWeights =
942 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
943 paramsInfo.m_RecurrentToCellWeights =
944 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
945 paramsInfo.m_RecurrentToOutputWeights =
946 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100947
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100948 paramsInfo.m_InputGateBias =
949 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
950 paramsInfo.m_ForgetGateBias =
951 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
952 paramsInfo.m_CellBias =
953 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
954 paramsInfo.m_OutputGateBias =
955 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100956
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000957 result = layerSupportObject.IsQuantizedLstmSupported(input,
958 previousCellStateIn,
959 previousOutputIn,
960 cellStateOut,
961 output,
962 paramsInfo,
963 reason);
James Conroyee18dc82019-07-17 11:27:46 +0100964 break;
965 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100966 case LayerType::Division:
967 {
968 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
969 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
970 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000971 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100972 OverrideDataType(input0, dataType),
973 OverrideDataType(input1, dataType),
974 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100975 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100976 break;
977 }
Finn Williams2605b232020-06-10 15:53:46 +0100978 case LayerType::Rank:
979 {
980 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
981 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000982 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
983 OverrideDataType(output, dataType),
984 reason);
Finn Williams2605b232020-06-10 15:53:46 +0100985 break;
986 }
telsoa014fcda012018-03-09 14:13:49 +0000987 case LayerType::Reshape:
988 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100989 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000990 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000991 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000992 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
993 OverrideDataType(output, dataType),
994 cLayer->GetParameters(),
995 reason);
telsoa014fcda012018-03-09 14:13:49 +0000996 break;
997 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100998 case LayerType::Resize:
999 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001000 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +01001001 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001002 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001003 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1004 OverrideDataType(output, dataType),
1005 cLayer->GetParameters(),
1006 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001007 break;
1008 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001009 case LayerType::Shape:
1010 {
1011 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1012 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1013
1014 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1015 OverrideDataType(output, dataType),
1016 reason);
1017 break;
1018 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001019 case LayerType::Slice:
1020 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001021 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001022
1023 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1024 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1025
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001026 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1027 OverrideDataType(output, dataType),
1028 cLayer->GetParameters(),
1029 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001030 break;
1031 }
telsoa014fcda012018-03-09 14:13:49 +00001032 case LayerType::Softmax:
1033 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001034 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001035 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001036 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001037 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1038 OverrideDataType(output, dataType),
1039 cLayer->GetParameters(),
1040 reason);
telsoa014fcda012018-03-09 14:13:49 +00001041 break;
1042 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001043 case LayerType::SpaceToBatchNd:
1044 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001045 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001046 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1047 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001048 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1049 OverrideDataType(output, dataType),
1050 cLayer->GetParameters(),
1051 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001052 break;
1053 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001054 case LayerType::SpaceToDepth:
1055 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001056 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001057
1058 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1059 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1060
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001061 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1062 OverrideDataType(output, dataType),
1063 cLayer->GetParameters(),
1064 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001065 break;
1066 }
telsoa014fcda012018-03-09 14:13:49 +00001067 case LayerType::Splitter:
1068 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001069 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001070 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001071
1072 // Get vector of all outputs.
1073 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1074 {
1075 return OverrideDataType(slot.GetTensorInfo(), dataType);
1076 };
Finn Williams3e54d032020-10-22 16:53:35 +01001077 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1078 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001079 std::vector<TensorInfo> outputs(beginI, endI);
1080
1081 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1082
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001083 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1084 outputPtrs,
1085 cLayer->GetParameters(),
1086 reason);
telsoa014fcda012018-03-09 14:13:49 +00001087 break;
1088 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001089 case LayerType::Stack:
1090 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001091 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001092
1093 // Get vector of all inputs.
1094 auto getTensorInfo = [&dataType](const InputSlot& slot)
1095 {
1096 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1097 };
Finn Williams3e54d032020-10-22 16:53:35 +01001098 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1099 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001100 std::vector<TensorInfo> inputs(beginI, endI);
1101
1102 auto getTensorInfoPtr = [](const TensorInfo& info)
1103 {
1104 return &info;
1105 };
Finn Williams3e54d032020-10-22 16:53:35 +01001106 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1107 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001108 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1109
1110 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1111
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001112 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001113
1114 break;
1115 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001116 case LayerType::StandIn:
1117 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001118 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001119
1120 // Get vector of all inputs.
1121 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1122 {
1123 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1124 };
1125 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1126 {
1127 return OverrideDataType(slot.GetTensorInfo(), dataType);
1128 };
Finn Williams3e54d032020-10-22 16:53:35 +01001129 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1130 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001131 std::vector<TensorInfo> inputs(beginI, endI);
1132
Finn Williams3e54d032020-10-22 16:53:35 +01001133 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1134 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001135 std::vector<TensorInfo> outputs(beginO, endO);
1136
1137
1138 auto getTensorInfoPtr = [](const TensorInfo& info)
1139 {
1140 return &info;
1141 };
Finn Williams3e54d032020-10-22 16:53:35 +01001142 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1143 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001144 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1145
Finn Williams3e54d032020-10-22 16:53:35 +01001146 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1147 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001148 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1149
1150
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001151 result = layerSupportObject.IsStandInSupported(inputPtrs,
1152 outputPtrs,
1153 cLayer->GetParameters(),
1154 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001155 break;
1156 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001157 case LayerType::StridedSlice:
1158 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001159 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001160 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1161 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001162 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1163 OverrideDataType(output, dataType),
1164 cLayer->GetParameters(),
1165 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001166 break;
1167 }
David Beckc2044fe2018-09-05 15:00:38 +01001168 case LayerType::Subtraction:
1169 {
1170 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1171 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001173 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001174 OverrideDataType(input0, dataType),
1175 OverrideDataType(input1, dataType),
1176 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001177 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001178 break;
1179 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001180 case LayerType::Switch:
1181 {
1182 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1183 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1184 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1185 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001186 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1187 OverrideDataType(input1, dataType),
1188 OverrideDataType(output0, dataType),
1189 OverrideDataType(output1, dataType),
1190 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001191 break;
1192 }
narpra0132b90462018-09-13 11:07:48 +01001193 case LayerType::Mean:
1194 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001195 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001196 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1197 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001198 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001199 OverrideDataType(input, dataType),
1200 OverrideDataType(output, dataType),
1201 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001202 reason);
narpra0132b90462018-09-13 11:07:48 +01001203 break;
1204 }
kevmay0190539692018-11-29 08:40:19 +00001205 case LayerType::Minimum:
1206 {
1207 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1208 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1209 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001210 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1211 OverrideDataType(input1, dataType),
1212 OverrideDataType(output, dataType),
1213 reason);
kevmay0190539692018-11-29 08:40:19 +00001214 break;
1215 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001216 case LayerType::Prelu:
1217 {
1218 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1219 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1220 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001221 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1222 OverrideDataType(alpha, dataType),
1223 OverrideDataType(output, dataType),
1224 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001225 break;
1226 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001227 case LayerType::Transpose:
1228 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001229 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001230 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1231 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001232 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1233 OverrideDataType(output, dataType),
1234 cLayer->GetParameters(),
1235 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001236 break;
1237 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001238 case LayerType::TransposeConvolution2d:
1239 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001240 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001241
1242 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1243 dataType);
1244 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1245
1246 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1247
1248 Optional<TensorInfo> biases;
1249 if (descriptor.m_BiasEnabled)
1250 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001251 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001252 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1253 GetBiasTypeFromWeightsType(dataType));
1254 }
1255
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001256 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001257 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1258
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001259 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1260 output,
1261 descriptor,
1262 weights,
1263 biases,
1264 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001265
1266 break;
1267 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001268 case LayerType::Reduce:
1269 {
1270 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1271 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1272 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1273
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001274 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1275 OverrideDataType(output, dataType),
1276 cLayer->GetParameters(),
1277 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001278 break;
1279 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001280 case LayerType::UnidirectionalSequenceLstm:
1281 {
1282 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1283 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1284
1285 // All inputs.
1286 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1287 dataType);
1288 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1289 dataType);
1290 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1291 dataType);
1292 // Outputs
1293 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1294
1295 // Basic parameters
1296 const TensorInfo& inputToForgetWeights
1297 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1298 const TensorInfo& inputToCellWeights
1299 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1300 const TensorInfo& inputToOutputWeights
1301 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1302 const TensorInfo& recurrentToForgetWeights
1303 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1304 const TensorInfo& recurrentToCellWeights
1305 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1306 const TensorInfo& recurrentToOutputWeights
1307 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1308 const TensorInfo& forgetGateBias
1309 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1310 const TensorInfo& cellBias
1311 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1312 const TensorInfo& outputGateBias
1313 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1314
1315 LstmInputParamsInfo paramsInfo;
1316
1317 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1318 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1319 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1320 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1321 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1322 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1323 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1324 paramsInfo.m_CellBias = &cellBias;
1325 paramsInfo.m_OutputGateBias = &outputGateBias;
1326
1327 // Optional parameters
1328 TensorInfo optInputToInputWeights;
1329 TensorInfo optRecurrentToInputWeights;
1330 TensorInfo optCellToInputWeights;
1331 TensorInfo optInputGateBias;
1332 TensorInfo optProjectionWeights;
1333 TensorInfo optProjectionBias;
1334 TensorInfo optCellToForgetWeights;
1335 TensorInfo optCellToOutputWeights;
1336 TensorInfo optInputLayerNormWeights;
1337 TensorInfo optForgetLayerNormWeights;
1338 TensorInfo optCellLayerNormWeights;
1339 TensorInfo optOutputLayerNormWeights;
1340
1341 if(!descriptor.m_CifgEnabled)
1342 {
1343 optInputToInputWeights =
1344 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1345 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1346
1347 optRecurrentToInputWeights =
1348 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1349 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1350 optInputGateBias =
1351 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1352 paramsInfo.m_InputGateBias = &optInputGateBias;
1353 }
1354
1355 if(descriptor.m_ProjectionEnabled)
1356 {
1357 optProjectionWeights =
1358 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1359 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1360 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1361 {
1362 optProjectionBias =
1363 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1364 paramsInfo.m_ProjectionBias = &optProjectionBias;
1365 }
1366 }
1367
1368 if(descriptor.m_PeepholeEnabled)
1369 {
1370 if(!descriptor.m_CifgEnabled)
1371 {
1372 optCellToInputWeights =
1373 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1374 dataType);
1375 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1376 }
1377 optCellToForgetWeights =
1378 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1379 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1380 optCellToOutputWeights =
1381 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1382 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1383 }
1384
1385 if(descriptor.m_LayerNormEnabled)
1386 {
1387 if (!descriptor.m_CifgEnabled)
1388 {
1389 optInputLayerNormWeights = OverrideDataType(
1390 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1391 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1392 }
1393
1394 optForgetLayerNormWeights = OverrideDataType(
1395 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1396 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1397
1398 optCellLayerNormWeights = OverrideDataType(
1399 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1400 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1401
1402 optOutputLayerNormWeights = OverrideDataType(
1403 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1404 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1405 }
1406
1407 Optional<TensorInfo> hiddenStateOut;
1408 Optional<TensorInfo> cellStateOut;
1409
1410 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1411 outputStateIn,
1412 cellStateIn,
1413 output,
1414 hiddenStateOut,
1415 cellStateOut,
1416 descriptor,
1417 paramsInfo,
1418 reason);
1419 break;
1420 }
telsoa014fcda012018-03-09 14:13:49 +00001421 default:
1422 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001423 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001424 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001425 result = false;
1426 break;
1427 }
1428 }
telsoa014fcda012018-03-09 14:13:49 +00001429 return result;
1430}
1431
Sadik Armagan045f6be2020-09-10 13:37:32 +01001432bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1433 const IConnectableLayer& connectableLayer,
1434 Optional<DataType> dataType,
1435 std::string& outReasonIfUnsupported)
1436{
1437 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1438}
1439
David Beckdcb751f2018-10-03 11:42:42 +01001440bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001441 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001442 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001443{
Jan Eilersbb446e52020-04-02 13:56:54 +01001444 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001445 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1446}
1447
1448// TODO merge with defaulted modelOptions above
1449bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1450 Optional<DataType> dataType,
1451 std::string& outReasonIfUnsupported,
1452 const ModelOptions& modelOptions)
1453{
1454 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1455 return IsLayerConfigurationSupported(layer->GetBackendId(),
1456 connectableLayer,
1457 dataType,
1458 outReasonIfUnsupported,
1459 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001460}
1461
Sadik Armagan04a72972020-09-14 15:44:18 +01001462bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1463 const IConnectableLayer& connectableLayer,
1464 Optional<DataType> dataType,
1465 std::string& outReasonIfUnsupported,
1466 const ModelOptions& modelOptions)
1467{
1468 return IsLayerConfigurationSupported(backendId,
1469 connectableLayer,
1470 dataType,
1471 outReasonIfUnsupported,
1472 modelOptions);
1473}
1474
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001475// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001476std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1477 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001478{
1479 return std::unique_ptr<IWorkload>();
1480}
1481
Derek Lamberti901ea112019-12-10 22:07:09 +00001482std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1483 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001484{
1485 return std::unique_ptr<IWorkload>();
1486}
1487
Derek Lamberti901ea112019-12-10 22:07:09 +00001488std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1489 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001490{
1491 return std::unique_ptr<IWorkload>();
1492}
1493
Derek Lamberti901ea112019-12-10 22:07:09 +00001494std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1495 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001496{
1497 return std::unique_ptr<IWorkload>();
1498}
1499
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001500std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001501 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001502{
1503 return std::unique_ptr<IWorkload>();
1504}
1505
Derek Lamberti901ea112019-12-10 22:07:09 +00001506std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1507 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001508{
1509 return std::unique_ptr<IWorkload>();
1510}
1511
mathad01b392e982021-04-07 12:07:30 +01001512std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1513 const WorkloadInfo& /*info*/) const
1514{
1515 return std::unique_ptr<IWorkload>();
1516}
1517
Derek Lamberti901ea112019-12-10 22:07:09 +00001518std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1519 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001520{
1521 return std::unique_ptr<IWorkload>();
1522}
1523
Derek Lamberti901ea112019-12-10 22:07:09 +00001524std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1525 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001526{
1527 return std::unique_ptr<IWorkload>();
1528}
1529
Derek Lamberti901ea112019-12-10 22:07:09 +00001530std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1531 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001532{
1533 return std::unique_ptr<IWorkload>();
1534}
1535
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001536std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1537 const WorkloadInfo& /*info*/) const
1538{
1539 return std::unique_ptr<IWorkload>();
1540}
1541
Derek Lamberti901ea112019-12-10 22:07:09 +00001542std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1543 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001544{
1545 return std::unique_ptr<IWorkload>();
1546}
1547
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001548std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1549 const WorkloadInfo& /*info*/) const
1550{
1551 return std::unique_ptr<IWorkload>();
1552}
1553
Derek Lamberti901ea112019-12-10 22:07:09 +00001554std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1555 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001556{
1557 return std::unique_ptr<IWorkload>();
1558}
1559
Derek Lamberti901ea112019-12-10 22:07:09 +00001560std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1561 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001562{
1563 return std::unique_ptr<IWorkload>();
1564}
1565
Derek Lamberti901ea112019-12-10 22:07:09 +00001566std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1567 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001568{
1569 return std::unique_ptr<IWorkload>();
1570}
1571
Derek Lamberti901ea112019-12-10 22:07:09 +00001572std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1573 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001574{
1575 return std::unique_ptr<IWorkload>();
1576}
1577
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001578std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001579 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001580{
1581 return std::unique_ptr<IWorkload>();
1582}
1583
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001584std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001585 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001586{
1587 return std::unique_ptr<IWorkload>();
1588}
1589
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001590std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001591 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001592{
1593 return std::unique_ptr<IWorkload>();
1594}
1595
Derek Lamberti901ea112019-12-10 22:07:09 +00001596std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1597 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001598{
1599 return std::unique_ptr<IWorkload>();
1600}
1601
josh minor4a3c6102020-01-06 16:40:46 -06001602std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1603 const WorkloadInfo& /*info*/) const
1604{
1605 return std::unique_ptr<IWorkload>();
1606}
1607
Derek Lamberti901ea112019-12-10 22:07:09 +00001608std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1609 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001610{
1611 return std::unique_ptr<IWorkload>();
1612}
1613
Derek Lamberti901ea112019-12-10 22:07:09 +00001614std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1615 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001616{
1617 return std::unique_ptr<IWorkload>();
1618}
1619
Ryan OSheaec6c6802020-06-05 17:17:06 +01001620std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1621 const WorkloadInfo& /*info*/) const
1622{
1623 return std::unique_ptr<IWorkload>();
1624}
1625
Derek Lamberti901ea112019-12-10 22:07:09 +00001626std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1627 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001628{
1629 return std::unique_ptr<IWorkload>();
1630}
1631
Derek Lamberti901ea112019-12-10 22:07:09 +00001632std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1633 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001634{
1635 return std::unique_ptr<IWorkload>();
1636}
1637
Derek Lamberti901ea112019-12-10 22:07:09 +00001638std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1639 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001640{
1641 return std::unique_ptr<IWorkload>();
1642}
1643
Derek Lamberti901ea112019-12-10 22:07:09 +00001644std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1645 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001646{
1647 return std::unique_ptr<IWorkload>();
1648}
1649
Kevin Mayce5045a2019-10-02 14:07:47 +01001650std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001651 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1652 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001653{
1654 return std::unique_ptr<IWorkload>();
1655}
1656
Derek Lamberti901ea112019-12-10 22:07:09 +00001657std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1658 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001659{
1660 return std::unique_ptr<IWorkload>();
1661}
1662
James Conroyaba90cd2020-11-06 16:28:18 +00001663std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1664 const WorkloadInfo& /*info*/) const
1665{
1666 return std::unique_ptr<IWorkload>();
1667}
1668
1669std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1670 const WorkloadInfo& /*info*/) const
1671{
1672 return std::unique_ptr<IWorkload>();
1673}
1674
Derek Lamberti901ea112019-12-10 22:07:09 +00001675std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1676 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001677{
1678 return std::unique_ptr<IWorkload>();
1679}
1680
Derek Lamberti901ea112019-12-10 22:07:09 +00001681std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1682 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001683{
1684 return std::unique_ptr<IWorkload>();
1685}
1686
Derek Lamberti901ea112019-12-10 22:07:09 +00001687std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1688 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001689{
1690 return std::unique_ptr<IWorkload>();
1691}
1692
Derek Lamberti901ea112019-12-10 22:07:09 +00001693std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1694 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001695{
1696 return std::unique_ptr<IWorkload>();
1697}
1698
Derek Lamberti901ea112019-12-10 22:07:09 +00001699std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1700 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001701{
1702 return std::unique_ptr<IWorkload>();
1703}
1704
Derek Lamberti901ea112019-12-10 22:07:09 +00001705std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1706 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001707{
1708 return std::unique_ptr<IWorkload>();
1709}
1710
Derek Lamberti901ea112019-12-10 22:07:09 +00001711std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1712 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001713{
1714 return std::unique_ptr<IWorkload>();
1715}
1716
Derek Lamberti901ea112019-12-10 22:07:09 +00001717std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1718 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001719{
1720 return std::unique_ptr<IWorkload>();
1721}
1722
Derek Lamberti901ea112019-12-10 22:07:09 +00001723std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1724 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001725{
1726 return std::unique_ptr<IWorkload>();
1727}
1728
Derek Lamberti901ea112019-12-10 22:07:09 +00001729std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1730 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001731{
1732 return std::unique_ptr<IWorkload>();
1733}
1734
Derek Lamberti901ea112019-12-10 22:07:09 +00001735std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1736 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001737{
1738 return std::unique_ptr<IWorkload>();
1739}
1740
Derek Lamberti901ea112019-12-10 22:07:09 +00001741std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1742 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001743{
1744 return std::unique_ptr<IWorkload>();
1745}
1746
Derek Lamberti901ea112019-12-10 22:07:09 +00001747std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1748 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001749{
1750 return std::unique_ptr<IWorkload>();
1751}
1752
Derek Lamberti901ea112019-12-10 22:07:09 +00001753std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001754 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001755{
1756 return std::unique_ptr<IWorkload>();
1757}
1758
Derek Lamberti901ea112019-12-10 22:07:09 +00001759std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1760 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001761{
1762 return std::unique_ptr<IWorkload>();
1763}
1764
Derek Lamberti901ea112019-12-10 22:07:09 +00001765std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1766 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001767{
1768 return std::unique_ptr<IWorkload>();
1769}
1770
Derek Lamberti901ea112019-12-10 22:07:09 +00001771std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1772 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001773{
1774 return std::unique_ptr<IWorkload>();
1775}
1776
Derek Lamberti901ea112019-12-10 22:07:09 +00001777std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1778 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001779{
1780 return std::unique_ptr<IWorkload>();
1781}
1782
James Conroy586a9aa2020-03-20 08:49:33 +00001783std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1784 const WorkloadInfo& /*info*/) const
1785{
1786 return std::unique_ptr<IWorkload>();
1787}
1788
Derek Lamberti901ea112019-12-10 22:07:09 +00001789std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1790 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001791{
1792 return std::unique_ptr<IWorkload>();
1793}
Finn Williams2605b232020-06-10 15:53:46 +01001794std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1795 const WorkloadInfo& /*info*/) const
1796{
1797 return std::unique_ptr<IWorkload>();
1798}
James Conroyee18dc82019-07-17 11:27:46 +01001799
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001800std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1801 const WorkloadInfo& /*info*/) const
1802{
1803 return std::unique_ptr<IWorkload>();
1804}
1805
Derek Lamberti901ea112019-12-10 22:07:09 +00001806std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1807 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001808{
1809 return std::unique_ptr<IWorkload>();
1810}
1811
Derek Lamberti901ea112019-12-10 22:07:09 +00001812std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1813 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001814{
1815 return std::unique_ptr<IWorkload>();
1816}
1817
Derek Lamberti901ea112019-12-10 22:07:09 +00001818std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1819 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001820{
1821 return std::unique_ptr<IWorkload>();
1822}
1823
Derek Lamberti901ea112019-12-10 22:07:09 +00001824std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1825 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001826{
1827 return std::unique_ptr<IWorkload>();
1828}
1829
Keith Davis3ae3f972021-05-21 16:33:48 +01001830std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
1831 const WorkloadInfo& /*info*/) const
1832{
1833 return std::unique_ptr<IWorkload>();
1834}
1835
Derek Lamberti901ea112019-12-10 22:07:09 +00001836std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1837 const WorkloadInfo& /*info*/) const
1838{
1839 return std::unique_ptr<IWorkload>();
1840}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001841
Derek Lamberti901ea112019-12-10 22:07:09 +00001842std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1843 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001844{
1845 return std::unique_ptr<IWorkload>();
1846}
1847
Derek Lamberti901ea112019-12-10 22:07:09 +00001848std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1849 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001850{
1851 return std::unique_ptr<IWorkload>();
1852}
1853
Derek Lamberti901ea112019-12-10 22:07:09 +00001854std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1855 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001856{
1857 return std::unique_ptr<IWorkload>();
1858}
1859
Derek Lamberti901ea112019-12-10 22:07:09 +00001860std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1861 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001862{
1863 return std::unique_ptr<IWorkload>();
1864}
1865
Derek Lamberti901ea112019-12-10 22:07:09 +00001866std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1867 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001868{
1869 return std::unique_ptr<IWorkload>();
1870}
1871
Derek Lamberti901ea112019-12-10 22:07:09 +00001872std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1873 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001874{
1875 return std::unique_ptr<IWorkload>();
1876}
1877
Derek Lamberti901ea112019-12-10 22:07:09 +00001878std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1879 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001880{
1881 return std::unique_ptr<IWorkload>();
1882}
1883
Derek Lamberti901ea112019-12-10 22:07:09 +00001884std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1885 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001886{
1887 return std::unique_ptr<IWorkload>();
1888}
1889
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001890std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1891 const WorkloadInfo& /*info*/) const
1892{
1893 return std::unique_ptr<IWorkload>();
1894}
1895
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001896std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001897 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1898 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001899{
1900 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001901}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001902
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001903std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
1904 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
1905 const WorkloadInfo& /*info*/) const
1906{
1907 return std::unique_ptr<IWorkload>();
1908}
1909
} // namespace armnn