blob: 9c47a19208c98b65c66c9ec9594faf31941f94b3 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000011#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010012#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000013#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000014#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010015#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010016#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017
Colm Donelan0c479742021-12-10 12:43:54 +000018#include <armnn/backends/WorkloadFactory.hpp>
19#include <armnn/backends/TensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
David Beck111b5d92018-11-12 14:59:37 +000021#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
telsoa01c577f2c2018-08-31 09:22:23 +010026namespace
27{
Finn Williams3e54d032020-10-22 16:53:35 +010028using LayerList = std::list<Layer*>;
29using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
Matthew Sloyan81beae32021-07-13 19:46:11 +010038 return TensorInfo(info.GetShape(),
39 type.value(),
40 info.GetQuantizationScale(),
41 info.GetQuantizationOffset(),
42 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010043}
44
David Beck29c75de2018-10-23 13:35:58 +010045} // anonymous namespace
46
Sadik Armagana097d2a2021-11-24 15:47:28 +000047inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
48{
49 if (!weightsType)
50 {
51 return weightsType;
52 }
53
54 switch(weightsType.value())
55 {
56 case armnn::DataType::BFloat16:
57 case armnn::DataType::Float16:
58 case armnn::DataType::Float32:
59 return weightsType;
60 case armnn::DataType::QAsymmS8:
61 case armnn::DataType::QAsymmU8:
62 case armnn::DataType::QSymmS8:
63 case armnn::DataType::QSymmS16:
64 return armnn::DataType::Signed32;
65 default:
66 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
67 }
68 return armnn::EmptyOptional();
69}
70
71
Sadik Armagan045f6be2020-09-10 13:37:32 +010072bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
73 const IConnectableLayer& connectableLayer,
74 Optional<DataType> dataType,
75 std::string& outReasonIfUnsupported,
76 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000077{
David Beck33f0ae02018-10-18 15:13:56 +010078 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000079 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010080 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010081
David Beck111b5d92018-11-12 14:59:37 +000082 auto const& backendRegistry = BackendRegistryInstance();
83 if (!backendRegistry.IsBackendRegistered(backendId))
84 {
85 std::stringstream ss;
86 ss << connectableLayer.GetName() << " is not supported on " << backendId
87 << " because this backend is not registered.";
88
89 outReasonIfUnsupported = ss.str();
90 return false;
91 }
92
93 auto backendFactory = backendRegistry.GetFactory(backendId);
94 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000095 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010096
telsoa014fcda012018-03-09 14:13:49 +000097 switch(layer.GetType())
98 {
99 case LayerType::Activation:
100 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100101 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000102 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100103 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000104 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100105 OverrideDataType(input, dataType),
106 OverrideDataType(output, dataType),
107 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100108 reason);
telsoa014fcda012018-03-09 14:13:49 +0000109 break;
110 }
111 case LayerType::Addition:
112 {
113 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
114 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
115 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000116 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100117 OverrideDataType(input0, dataType),
118 OverrideDataType(input1, dataType),
119 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100120 reason);
telsoa014fcda012018-03-09 14:13:49 +0000121 break;
122 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100123 case LayerType::ArgMinMax:
124 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100125 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100126 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
127
128 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
129 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000130 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100131 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000132 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100133 descriptor,
134 reason);
135 break;
136 }
telsoa014fcda012018-03-09 14:13:49 +0000137 case LayerType::BatchNormalization:
138 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100139 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000140 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100141 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
142 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
143 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
144 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
145 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000146 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100147 OverrideDataType(input, dataType),
148 OverrideDataType(output, dataType),
149 OverrideDataType(mean, dataType),
150 OverrideDataType(var, dataType),
151 OverrideDataType(beta, dataType),
152 OverrideDataType(gamma, dataType),
153 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100154 reason);
telsoa014fcda012018-03-09 14:13:49 +0000155 break;
156 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000157 case LayerType::BatchToSpaceNd:
158 {
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100161 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000162
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000163 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
164 OverrideDataType(output, dataType),
165 cLayer->GetParameters(),
166 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000167 break;
168 }
mathad01b392e982021-04-07 12:07:30 +0100169 case LayerType::Cast:
170 {
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
173
174 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
175 OverrideDataType(output, dataType),
176 reason);
177 break;
178 }
Simon Obute51f67772021-09-03 15:50:13 +0100179 case LayerType::ChannelShuffle:
180 {
181 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
182
183 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
184 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
185
186 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
187
188 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
189 OverrideDataType(output, dataType),
190 descriptor,
191 reason);
192 break;
193 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100194 case LayerType::Comparison:
195 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100196 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100197
198 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
199 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
200 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
201
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000202 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
203 OverrideDataType(input1, dataType),
204 OverrideDataType(output, DataType::Boolean),
205 cLayer->GetParameters(),
206 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100207 break;
208 }
telsoa014fcda012018-03-09 14:13:49 +0000209 case LayerType::Constant:
210 {
211 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000212 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100213 break;
214 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000215 case LayerType::ConvertBf16ToFp32:
216 {
217 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
218 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000219 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000220 break;
221 }
telsoa01c577f2c2018-08-31 09:22:23 +0100222 case LayerType::ConvertFp16ToFp32:
223 {
224 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
225 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000226 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100227 break;
228 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000229 case LayerType::ConvertFp32ToBf16:
230 {
231 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
232 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000233 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000234 break;
235 }
telsoa01c577f2c2018-08-31 09:22:23 +0100236 case LayerType::ConvertFp32ToFp16:
237 {
238 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000240 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000241 break;
242 }
243 case LayerType::Convolution2d:
244 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100245 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100246
247 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
248 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100249 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100250 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100251
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100252 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100253
arovir01a6824102018-08-28 17:40:45 +0100254 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100255 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100256 if (descriptor.m_BiasEnabled)
257 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100258 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100259 }
260
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000261 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100262 input,
263 output,
264 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100265 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100266 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100267 reason);
telsoa014fcda012018-03-09 14:13:49 +0000268 break;
269 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100270 case LayerType::Convolution3d:
271 {
272 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
273
274 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
275 dataType);
276 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100277
278 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
279 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
280 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
281 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100282
283 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
284
285 // Construct optional biases object based on the value of m_BiasEnabled
286 Optional<TensorInfo> biases;
287 if (descriptor.m_BiasEnabled)
288 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100289 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
290 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100291 }
292
293 result = layerSupportObject.IsConvolution3dSupported(
294 input,
295 output,
296 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100297 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100298 biases,
299 reason);
300 break;
301 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000302 case LayerType::Debug:
303 {
304 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
305 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
306
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000307 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000308 OverrideDataType(output, dataType),
309 reason);
310 break;
311 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100312 case LayerType::DepthToSpace:
313 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100314 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100315
316 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
317 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
318
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000319 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100320 OverrideDataType(output, dataType),
321 cLayer->GetParameters(),
322 reason);
323 break;
324 }
telsoa014fcda012018-03-09 14:13:49 +0000325 case LayerType::DepthwiseConvolution2d:
326 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100327 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100328 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
329 dataType);
330 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100331 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100332
telsoa01c577f2c2018-08-31 09:22:23 +0100333 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100334
335 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100336 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100337 if (descriptor.m_BiasEnabled)
338 {
David Beck5eec11d2018-10-04 15:43:17 +0100339 biases =
340 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100341 }
telsoa01c577f2c2018-08-31 09:22:23 +0100342
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000343 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100344 input,
345 output,
346 descriptor,
347 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100348 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100349 reason);
telsoa014fcda012018-03-09 14:13:49 +0000350 break;
351 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000352 case LayerType::Dequantize:
353 {
354 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
355 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
356
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000357 result = layerSupportObject.IsDequantizeSupported(input,
358 OverrideDataType(output, dataType),
359 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000360 break;
361 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000362 case LayerType::DetectionPostProcess:
363 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100364 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000365 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
366 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
367 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
368
369 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
370 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
371 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
372 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
373
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000374 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000375 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
376 scores,
377 anchors,
378 detectionBoxes,
379 detectionClasses,
380 detectionScores,
381 numDetections,
382 descriptor,
383 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000384 break;
385 }
josh minor4a3c6102020-01-06 16:40:46 -0600386 case LayerType::ElementwiseUnary:
387 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100388 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600389
390 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
391 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
392
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000393 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
394 OverrideDataType(output, dataType),
395 cLayer->GetParameters(),
396 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600397 break;
398 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100399 case LayerType::Fill:
400 {
401 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
402 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
403 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
404 const FillDescriptor& descriptor = cLayer->GetParameters();
405
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000406 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100407 OverrideDataType(input, dataType),
408 OverrideDataType(output, dataType),
409 descriptor,
410 reason);
411 break;
412 }
telsoa014fcda012018-03-09 14:13:49 +0000413 case LayerType::FakeQuantization:
414 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100415 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000416 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000417 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
418 cLayer->GetParameters(),
419 reason);
telsoa014fcda012018-03-09 14:13:49 +0000420 break;
421 }
422 case LayerType::Floor:
423 {
424 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
425 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000426 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
427 OverrideDataType(output, dataType),
428 reason);
telsoa014fcda012018-03-09 14:13:49 +0000429 break;
430 }
431 case LayerType::FullyConnected:
432 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100433 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000434 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100435 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000436
437 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
438 TensorInfo weightsInfo;
439 const TensorInfo* weightsInfoPtr = nullptr;
440
Matthew Sloyan81beae32021-07-13 19:46:11 +0100441 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000442 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100443
444 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000445 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000446 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100447 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
448 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
449 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
450
telsoa01c577f2c2018-08-31 09:22:23 +0100451 if (descriptor.m_BiasEnabled)
452 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100453 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
454 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100455 }
456 else
457 {
458 // If biases are not enabled pass a dummy tensorinfo for the validation
459 switch(input.GetDataType())
460 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000461 case DataType::BFloat16:
462 {
463 biasInfoPtr = &dummyBFloat16Bias;
464 break;
465 }
telsoa01c577f2c2018-08-31 09:22:23 +0100466 case DataType::Float16:
467 {
468 biasInfoPtr = &dummyFloat16Bias;
469 break;
470 }
471 case DataType::Float32:
472 {
473 biasInfoPtr = &dummyFloat32Bias;
474 break;
475 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000476 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000477 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000478 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000479 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100480 {
481 biasInfoPtr = &dummyQA8Bias;
482 break;
483 }
484 default:
485 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100486 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100487 }
488 }
489 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000490 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100491 OverrideDataType(input, dataType),
492 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000493 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100494 *biasInfoPtr,
495 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100496 reason);
telsoa014fcda012018-03-09 14:13:49 +0000497 break;
498 }
narpra01b89b05f2019-01-16 09:53:09 +0000499 case LayerType::Gather:
500 {
501 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
502 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
503 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100504 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
505 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000506 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
507 input1,
508 OverrideDataType(output, dataType),
509 descriptor,
510 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000511 break;
512 }
telsoa014fcda012018-03-09 14:13:49 +0000513 case LayerType::Input:
514 {
515 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000516 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000517 break;
518 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100519 case LayerType::InstanceNormalization:
520 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100521 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100522 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
523
524 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
525 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
526
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000527 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100528 OverrideDataType(input, dataType),
529 OverrideDataType(output, dataType),
530 descriptor,
531 reason);
532 break;
533 }
telsoa014fcda012018-03-09 14:13:49 +0000534 case LayerType::L2Normalization:
535 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100536 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100537 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
538
telsoa014fcda012018-03-09 14:13:49 +0000539 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100540 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100541
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000542 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100543 OverrideDataType(input, dataType),
544 OverrideDataType(output, dataType),
545 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100546 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100547 break;
548 }
James Conroyaba90cd2020-11-06 16:28:18 +0000549 case LayerType::LogicalBinary:
550 {
551 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
552
553 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
554 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
555 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
556
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000557 result = layerSupportObject.IsLogicalBinarySupported(input0,
558 input1,
559 output,
560 cLayer->GetParameters(),
561 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000562 break;
563 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100564 case LayerType::LogSoftmax:
565 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100566 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100567
568 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
569 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
570
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000571 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
572 OverrideDataType(output, dataType),
573 cLayer->GetParameters(),
574 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100575 break;
576 }
        case LayerType::Lstm:
        {
            // Float LSTM support query: three inputs (input, output state, cell state),
            // four outputs (scratch buffer, output state, cell state, output).
            auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // paramsInfo stores raw pointers, so every TensorInfo it refers to must
            // stay alive until the IsLstmSupported() call below.
            LstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToForgetWeights     = &inputToForgetWeights;
            paramsInfo.m_InputToCellWeights       = &inputToCellWeights;
            paramsInfo.m_InputToOutputWeights     = &inputToOutputWeights;
            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
            paramsInfo.m_RecurrentToCellWeights   = &recurrentToCellWeights;
            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
            paramsInfo.m_ForgetGateBias           = &forgetGateBias;
            paramsInfo.m_CellBias                 = &cellBias;
            paramsInfo.m_OutputGateBias           = &outputGateBias;


            // Optional parameters
            // Declared at case scope (not inside the if-blocks) so the pointers taken
            // below remain valid for the support call.
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;
            TensorInfo optInputLayerNormWeights;
            TensorInfo optForgetLayerNormWeights;
            TensorInfo optCellLayerNormWeights;
            TensorInfo optOutputLayerNormWeights;

            // Input-gate parameters exist only when CIFG (coupled input-forget gate) is off.
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                paramsInfo.m_InputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                // Projection bias is optional even when projection itself is enabled.
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when the input gate exists (no CIFG).
                if(!descriptor.m_CifgEnabled)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
                                         dataType);
                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                }
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input layer-norm weights only exist when the input gate exists (no CIFG).
                if (!descriptor.m_CifgEnabled)
                {
                    optInputLayerNormWeights = OverrideDataType(
                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
                }

                optForgetLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;

                optCellLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;

                optOutputLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
            }

            result = layerSupportObject.IsLstmSupported(
                                           input,
                                           outputStateIn,
                                           cellStateIn,
                                           scratchBuffer,
                                           outputStateOut,
                                           cellStateOut,
                                           output,
                                           descriptor,
                                           paramsInfo,
                                           reason);
            break;
        }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000721 case LayerType::Maximum:
722 {
723 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
724 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
725 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
726
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000727 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
728 OverrideDataType(input1, dataType),
729 OverrideDataType(output, dataType),
730 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000731 break;
732 }
narpra01b89b05f2019-01-16 09:53:09 +0000733 case LayerType::MemCopy:
734 {
735 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
736 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000737
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000738 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
739 OverrideDataType(output, dataType),
740 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000741 break;
742 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100743 case LayerType::MemImport:
744 {
745 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
746 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
747
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000748 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
749 OverrideDataType(output, dataType),
750 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100751 break;
752 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100753 case LayerType::Merge:
754 {
755 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
756 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
757 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
758
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000759 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
760 OverrideDataType(input1, dataType),
761 OverrideDataType(output, dataType),
762 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100763 break;
764 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100765 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000766 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100767 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000768
telsoa01c577f2c2018-08-31 09:22:23 +0100769 // Get vector of all inputs.
770 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000771 {
telsoa01c577f2c2018-08-31 09:22:23 +0100772 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000773 };
Finn Williams3e54d032020-10-22 16:53:35 +0100774
775 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
776 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100777 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000778
telsoa01c577f2c2018-08-31 09:22:23 +0100779 auto getTensorInfoPtr = [](const TensorInfo& info)
780 {
781 return &info;
782 };
Finn Williams3e54d032020-10-22 16:53:35 +0100783
784 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
785 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100786 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000787
Nikhil Raj8599a412018-11-19 14:51:07 +0000788 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
789
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000790 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100791
792
telsoa014fcda012018-03-09 14:13:49 +0000793 break;
794 }
795 case LayerType::Multiplication:
796 {
797 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
798 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100799 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000800 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100801 OverrideDataType(input0, dataType),
802 OverrideDataType(input1, dataType),
803 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100804 reason);
telsoa014fcda012018-03-09 14:13:49 +0000805 break;
806 }
807 case LayerType::Normalization:
808 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100809 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000810 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
811 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000812 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
813 OverrideDataType(output, dataType),
814 cLayer->GetParameters(),
815 reason);
telsoa014fcda012018-03-09 14:13:49 +0000816 break;
817 }
818 case LayerType::Output:
819 {
820 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000821 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000822 break;
823 }
824 case LayerType::Permute:
825 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100826 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000827 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
828 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000829 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
830 OverrideDataType(output, dataType),
831 cLayer->GetParameters(),
832 reason);
telsoa014fcda012018-03-09 14:13:49 +0000833 break;
834 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100835 case LayerType::Pad:
836 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100837 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100838 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
839 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000840 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100841 OverrideDataType(input, dataType),
842 OverrideDataType(output, dataType),
843 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100844 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100845 break;
846 }
telsoa014fcda012018-03-09 14:13:49 +0000847 case LayerType::Pooling2d:
848 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100849 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000850 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
851 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000852 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
853 OverrideDataType(output, dataType),
854 cLayer->GetParameters(),
855 reason);
telsoa014fcda012018-03-09 14:13:49 +0000856 break;
857 }
Tamás Nyíri7b885b32021-10-26 14:47:57 +0100858 case LayerType::Pooling3d:
859 {
860 auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
861 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
862 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
863 result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
864 OverrideDataType(output, dataType),
865 cLayer->GetParameters(),
866 reason);
867 break;
868 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000869 case LayerType::PreCompiled:
870 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100871 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000872 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000873 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
874 cLayer->GetParameters(),
875 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000876 break;
877 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000878 case LayerType::Quantize:
879 {
880 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
881 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000882 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000883 break;
884 }
        case LayerType::QLstm:
        {
            // Quantized (8-bit) LSTM support query: three inputs, three outputs.
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters
            // paramsInfo stores raw pointers into the layer-owned tensor infos, which
            // remain alive for the duration of this call.
            LstmInputParamsInfo paramsInfo;

            // Basic parameters
            // The mandatory weight tensors must have been set on the layer before a
            // support query is made.
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // Input-gate parameters exist only when CIFG (coupled input-forget gate) is off.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                        &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when the input gate exists (no CIFG).
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                            &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                        &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input layer-norm weights only exist when the input gate exists (no CIFG).
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                            &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +0100981 case LayerType::QuantizedLstm:
982 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100983 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100984
985 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100986 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
987 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
988 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100989
990 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100991 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
992 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100993
994 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100995 QuantizedLstmInputParamsInfo paramsInfo;
996
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100997 paramsInfo.m_InputToInputWeights =
998 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
999 paramsInfo.m_InputToForgetWeights =
1000 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1001 paramsInfo.m_InputToCellWeights =
1002 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1003 paramsInfo.m_InputToOutputWeights =
1004 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001005
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001006 paramsInfo.m_RecurrentToInputWeights =
1007 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1008 paramsInfo.m_RecurrentToForgetWeights =
1009 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1010 paramsInfo.m_RecurrentToCellWeights =
1011 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1012 paramsInfo.m_RecurrentToOutputWeights =
1013 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001014
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001015 paramsInfo.m_InputGateBias =
1016 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1017 paramsInfo.m_ForgetGateBias =
1018 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1019 paramsInfo.m_CellBias =
1020 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1021 paramsInfo.m_OutputGateBias =
1022 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001023
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001024 result = layerSupportObject.IsQuantizedLstmSupported(input,
1025 previousCellStateIn,
1026 previousOutputIn,
1027 cellStateOut,
1028 output,
1029 paramsInfo,
1030 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001031 break;
1032 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001033 case LayerType::Division:
1034 {
1035 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1036 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1037 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001038 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001039 OverrideDataType(input0, dataType),
1040 OverrideDataType(input1, dataType),
1041 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001042 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001043 break;
1044 }
Finn Williams2605b232020-06-10 15:53:46 +01001045 case LayerType::Rank:
1046 {
1047 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1048 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001049 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1050 OverrideDataType(output, dataType),
1051 reason);
Finn Williams2605b232020-06-10 15:53:46 +01001052 break;
1053 }
telsoa014fcda012018-03-09 14:13:49 +00001054 case LayerType::Reshape:
1055 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001056 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001057 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +00001058 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001059 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1060 OverrideDataType(output, dataType),
1061 cLayer->GetParameters(),
1062 reason);
telsoa014fcda012018-03-09 14:13:49 +00001063 break;
1064 }
Teresa Charlina9075df2019-06-27 15:41:57 +01001065 case LayerType::Resize:
1066 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001067 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +01001068 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001069 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001070 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1071 OverrideDataType(output, dataType),
1072 cLayer->GetParameters(),
1073 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001074 break;
1075 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001076 case LayerType::Shape:
1077 {
1078 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1080
1081 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1082 OverrideDataType(output, dataType),
1083 reason);
1084 break;
1085 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001086 case LayerType::Slice:
1087 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001088 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001089
1090 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1091 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1092
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001093 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1094 OverrideDataType(output, dataType),
1095 cLayer->GetParameters(),
1096 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001097 break;
1098 }
telsoa014fcda012018-03-09 14:13:49 +00001099 case LayerType::Softmax:
1100 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001101 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001102 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001103 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001104 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1105 OverrideDataType(output, dataType),
1106 cLayer->GetParameters(),
1107 reason);
telsoa014fcda012018-03-09 14:13:49 +00001108 break;
1109 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001110 case LayerType::SpaceToBatchNd:
1111 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001112 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001113 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1114 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001115 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1116 OverrideDataType(output, dataType),
1117 cLayer->GetParameters(),
1118 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001119 break;
1120 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001121 case LayerType::SpaceToDepth:
1122 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001123 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001124
1125 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1126 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1127
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001128 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1129 OverrideDataType(output, dataType),
1130 cLayer->GetParameters(),
1131 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001132 break;
1133 }
telsoa014fcda012018-03-09 14:13:49 +00001134 case LayerType::Splitter:
1135 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001136 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001137 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001138
1139 // Get vector of all outputs.
1140 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1141 {
1142 return OverrideDataType(slot.GetTensorInfo(), dataType);
1143 };
Finn Williams3e54d032020-10-22 16:53:35 +01001144 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1145 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001146 std::vector<TensorInfo> outputs(beginI, endI);
1147
1148 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1149
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001150 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1151 outputPtrs,
1152 cLayer->GetParameters(),
1153 reason);
telsoa014fcda012018-03-09 14:13:49 +00001154 break;
1155 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001156 case LayerType::Stack:
1157 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001158 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001159
1160 // Get vector of all inputs.
1161 auto getTensorInfo = [&dataType](const InputSlot& slot)
1162 {
1163 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1164 };
Finn Williams3e54d032020-10-22 16:53:35 +01001165 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1166 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001167 std::vector<TensorInfo> inputs(beginI, endI);
1168
1169 auto getTensorInfoPtr = [](const TensorInfo& info)
1170 {
1171 return &info;
1172 };
Finn Williams3e54d032020-10-22 16:53:35 +01001173 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1174 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001175 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1176
1177 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1178
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001179 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001180
1181 break;
1182 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001183 case LayerType::StandIn:
1184 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001185 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001186
1187 // Get vector of all inputs.
1188 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1189 {
1190 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1191 };
1192 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1193 {
1194 return OverrideDataType(slot.GetTensorInfo(), dataType);
1195 };
Finn Williams3e54d032020-10-22 16:53:35 +01001196 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1197 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001198 std::vector<TensorInfo> inputs(beginI, endI);
1199
Finn Williams3e54d032020-10-22 16:53:35 +01001200 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1201 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001202 std::vector<TensorInfo> outputs(beginO, endO);
1203
1204
1205 auto getTensorInfoPtr = [](const TensorInfo& info)
1206 {
1207 return &info;
1208 };
Finn Williams3e54d032020-10-22 16:53:35 +01001209 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1210 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001211 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1212
Finn Williams3e54d032020-10-22 16:53:35 +01001213 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1214 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001215 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1216
1217
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001218 result = layerSupportObject.IsStandInSupported(inputPtrs,
1219 outputPtrs,
1220 cLayer->GetParameters(),
1221 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001222 break;
1223 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001224 case LayerType::StridedSlice:
1225 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001226 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001227 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1228 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001229 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1230 OverrideDataType(output, dataType),
1231 cLayer->GetParameters(),
1232 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001233 break;
1234 }
David Beckc2044fe2018-09-05 15:00:38 +01001235 case LayerType::Subtraction:
1236 {
1237 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1238 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001240 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001241 OverrideDataType(input0, dataType),
1242 OverrideDataType(input1, dataType),
1243 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001244 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001245 break;
1246 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001247 case LayerType::Switch:
1248 {
1249 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1250 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1251 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1252 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001253 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1254 OverrideDataType(input1, dataType),
1255 OverrideDataType(output0, dataType),
1256 OverrideDataType(output1, dataType),
1257 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001258 break;
1259 }
narpra0132b90462018-09-13 11:07:48 +01001260 case LayerType::Mean:
1261 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001262 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001263 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1264 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001265 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001266 OverrideDataType(input, dataType),
1267 OverrideDataType(output, dataType),
1268 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001269 reason);
narpra0132b90462018-09-13 11:07:48 +01001270 break;
1271 }
kevmay0190539692018-11-29 08:40:19 +00001272 case LayerType::Minimum:
1273 {
1274 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1275 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1276 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001277 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1278 OverrideDataType(input1, dataType),
1279 OverrideDataType(output, dataType),
1280 reason);
kevmay0190539692018-11-29 08:40:19 +00001281 break;
1282 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001283 case LayerType::Prelu:
1284 {
1285 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1286 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1287 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001288 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1289 OverrideDataType(alpha, dataType),
1290 OverrideDataType(output, dataType),
1291 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001292 break;
1293 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001294 case LayerType::Transpose:
1295 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001296 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001297 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1298 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001299 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1300 OverrideDataType(output, dataType),
1301 cLayer->GetParameters(),
1302 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001303 break;
1304 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001305 case LayerType::TransposeConvolution2d:
1306 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001307 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001308
1309 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1310 dataType);
1311 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1312
1313 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1314
1315 Optional<TensorInfo> biases;
1316 if (descriptor.m_BiasEnabled)
1317 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001318 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001319 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1320 GetBiasTypeFromWeightsType(dataType));
1321 }
1322
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001323 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001324 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1325
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001326 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1327 output,
1328 descriptor,
1329 weights,
1330 biases,
1331 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001332
1333 break;
1334 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001335 case LayerType::Reduce:
1336 {
1337 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1338 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1339 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1340
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001341 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1342 OverrideDataType(output, dataType),
1343 cLayer->GetParameters(),
1344 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001345 break;
1346 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001347 case LayerType::UnidirectionalSequenceLstm:
1348 {
1349 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1350 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1351
1352 // All inputs.
1353 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1354 dataType);
1355 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1356 dataType);
1357 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1358 dataType);
1359 // Outputs
1360 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1361
1362 // Basic parameters
1363 const TensorInfo& inputToForgetWeights
1364 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1365 const TensorInfo& inputToCellWeights
1366 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1367 const TensorInfo& inputToOutputWeights
1368 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1369 const TensorInfo& recurrentToForgetWeights
1370 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1371 const TensorInfo& recurrentToCellWeights
1372 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1373 const TensorInfo& recurrentToOutputWeights
1374 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1375 const TensorInfo& forgetGateBias
1376 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1377 const TensorInfo& cellBias
1378 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1379 const TensorInfo& outputGateBias
1380 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1381
1382 LstmInputParamsInfo paramsInfo;
1383
1384 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1385 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1386 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1387 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1388 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1389 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1390 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1391 paramsInfo.m_CellBias = &cellBias;
1392 paramsInfo.m_OutputGateBias = &outputGateBias;
1393
1394 // Optional parameters
1395 TensorInfo optInputToInputWeights;
1396 TensorInfo optRecurrentToInputWeights;
1397 TensorInfo optCellToInputWeights;
1398 TensorInfo optInputGateBias;
1399 TensorInfo optProjectionWeights;
1400 TensorInfo optProjectionBias;
1401 TensorInfo optCellToForgetWeights;
1402 TensorInfo optCellToOutputWeights;
1403 TensorInfo optInputLayerNormWeights;
1404 TensorInfo optForgetLayerNormWeights;
1405 TensorInfo optCellLayerNormWeights;
1406 TensorInfo optOutputLayerNormWeights;
1407
1408 if(!descriptor.m_CifgEnabled)
1409 {
1410 optInputToInputWeights =
1411 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1412 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1413
1414 optRecurrentToInputWeights =
1415 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1416 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1417 optInputGateBias =
1418 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1419 paramsInfo.m_InputGateBias = &optInputGateBias;
1420 }
1421
1422 if(descriptor.m_ProjectionEnabled)
1423 {
1424 optProjectionWeights =
1425 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1426 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1427 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1428 {
1429 optProjectionBias =
1430 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1431 paramsInfo.m_ProjectionBias = &optProjectionBias;
1432 }
1433 }
1434
1435 if(descriptor.m_PeepholeEnabled)
1436 {
1437 if(!descriptor.m_CifgEnabled)
1438 {
1439 optCellToInputWeights =
1440 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1441 dataType);
1442 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1443 }
1444 optCellToForgetWeights =
1445 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1446 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1447 optCellToOutputWeights =
1448 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1449 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1450 }
1451
1452 if(descriptor.m_LayerNormEnabled)
1453 {
1454 if (!descriptor.m_CifgEnabled)
1455 {
1456 optInputLayerNormWeights = OverrideDataType(
1457 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1458 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1459 }
1460
1461 optForgetLayerNormWeights = OverrideDataType(
1462 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1463 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1464
1465 optCellLayerNormWeights = OverrideDataType(
1466 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1467 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1468
1469 optOutputLayerNormWeights = OverrideDataType(
1470 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1471 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1472 }
1473
1474 Optional<TensorInfo> hiddenStateOut;
1475 Optional<TensorInfo> cellStateOut;
1476
1477 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1478 outputStateIn,
1479 cellStateIn,
1480 output,
1481 hiddenStateOut,
1482 cellStateOut,
1483 descriptor,
1484 paramsInfo,
1485 reason);
1486 break;
1487 }
telsoa014fcda012018-03-09 14:13:49 +00001488 default:
1489 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001490 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001491 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001492 result = false;
1493 break;
1494 }
1495 }
telsoa014fcda012018-03-09 14:13:49 +00001496 return result;
1497}
1498
Sadik Armagan045f6be2020-09-10 13:37:32 +01001499bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1500 const IConnectableLayer& connectableLayer,
1501 Optional<DataType> dataType,
1502 std::string& outReasonIfUnsupported)
1503{
1504 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1505}
1506
David Beckdcb751f2018-10-03 11:42:42 +01001507bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001508 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001509 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001510{
Jan Eilersbb446e52020-04-02 13:56:54 +01001511 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001512 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1513}
1514
1515// TODO merge with defaulted modelOptions above
1516bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1517 Optional<DataType> dataType,
1518 std::string& outReasonIfUnsupported,
1519 const ModelOptions& modelOptions)
1520{
1521 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1522 return IsLayerConfigurationSupported(layer->GetBackendId(),
1523 connectableLayer,
1524 dataType,
1525 outReasonIfUnsupported,
1526 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001527}
1528
Sadik Armagan04a72972020-09-14 15:44:18 +01001529bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1530 const IConnectableLayer& connectableLayer,
1531 Optional<DataType> dataType,
1532 std::string& outReasonIfUnsupported,
1533 const ModelOptions& modelOptions)
1534{
1535 return IsLayerConfigurationSupported(backendId,
1536 connectableLayer,
1537 dataType,
1538 outReasonIfUnsupported,
1539 modelOptions);
1540}
1541
Derek Lamberti901ea112019-12-10 22:07:09 +00001542std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1543 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001544{
1545 return std::unique_ptr<IWorkload>();
1546}
1547
Derek Lamberti901ea112019-12-10 22:07:09 +00001548std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1549 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001550{
1551 return std::unique_ptr<IWorkload>();
1552}
1553
Derek Lamberti901ea112019-12-10 22:07:09 +00001554std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1555 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001556{
1557 return std::unique_ptr<IWorkload>();
1558}
1559
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001560std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001561 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001562{
1563 return std::unique_ptr<IWorkload>();
1564}
1565
Derek Lamberti901ea112019-12-10 22:07:09 +00001566std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1567 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001568{
1569 return std::unique_ptr<IWorkload>();
1570}
1571
mathad01b392e982021-04-07 12:07:30 +01001572std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1573 const WorkloadInfo& /*info*/) const
1574{
1575 return std::unique_ptr<IWorkload>();
1576}
1577
Simon Obute51f67772021-09-03 15:50:13 +01001578std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1579 const WorkloadInfo& /*info*/) const
1580{
1581 return std::unique_ptr<IWorkload>();
1582}
1583
Derek Lamberti901ea112019-12-10 22:07:09 +00001584std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1585 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001586{
1587 return std::unique_ptr<IWorkload>();
1588}
1589
Derek Lamberti901ea112019-12-10 22:07:09 +00001590std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1591 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001592{
1593 return std::unique_ptr<IWorkload>();
1594}
1595
Derek Lamberti901ea112019-12-10 22:07:09 +00001596std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1597 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001598{
1599 return std::unique_ptr<IWorkload>();
1600}
1601
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001602std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1603 const WorkloadInfo& /*info*/) const
1604{
1605 return std::unique_ptr<IWorkload>();
1606}
1607
Derek Lamberti901ea112019-12-10 22:07:09 +00001608std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1609 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001610{
1611 return std::unique_ptr<IWorkload>();
1612}
1613
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001614std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1615 const WorkloadInfo& /*info*/) const
1616{
1617 return std::unique_ptr<IWorkload>();
1618}
1619
Derek Lamberti901ea112019-12-10 22:07:09 +00001620std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1621 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001622{
1623 return std::unique_ptr<IWorkload>();
1624}
1625
Derek Lamberti901ea112019-12-10 22:07:09 +00001626std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1627 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001628{
1629 return std::unique_ptr<IWorkload>();
1630}
1631
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001632std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/,
1633 const WorkloadInfo& /*info*/) const
1634{
1635 return std::unique_ptr<IWorkload>();
1636}
1637
Derek Lamberti901ea112019-12-10 22:07:09 +00001638std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1639 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001640{
1641 return std::unique_ptr<IWorkload>();
1642}
1643
Derek Lamberti901ea112019-12-10 22:07:09 +00001644std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1645 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001646{
1647 return std::unique_ptr<IWorkload>();
1648}
1649
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001650std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001651 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001652{
1653 return std::unique_ptr<IWorkload>();
1654}
1655
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001656std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001657 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001658{
1659 return std::unique_ptr<IWorkload>();
1660}
1661
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001662std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001663 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001664{
1665 return std::unique_ptr<IWorkload>();
1666}
1667
Derek Lamberti901ea112019-12-10 22:07:09 +00001668std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1669 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001670{
1671 return std::unique_ptr<IWorkload>();
1672}
1673
josh minor4a3c6102020-01-06 16:40:46 -06001674std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1675 const WorkloadInfo& /*info*/) const
1676{
1677 return std::unique_ptr<IWorkload>();
1678}
1679
Derek Lamberti901ea112019-12-10 22:07:09 +00001680std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1681 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001682{
1683 return std::unique_ptr<IWorkload>();
1684}
1685
Ryan OSheaec6c6802020-06-05 17:17:06 +01001686std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1687 const WorkloadInfo& /*info*/) const
1688{
1689 return std::unique_ptr<IWorkload>();
1690}
1691
Derek Lamberti901ea112019-12-10 22:07:09 +00001692std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1693 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001694{
1695 return std::unique_ptr<IWorkload>();
1696}
1697
Derek Lamberti901ea112019-12-10 22:07:09 +00001698std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1699 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001700{
1701 return std::unique_ptr<IWorkload>();
1702}
1703
Derek Lamberti901ea112019-12-10 22:07:09 +00001704std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1705 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001706{
1707 return std::unique_ptr<IWorkload>();
1708}
1709
Kevin Mayce5045a2019-10-02 14:07:47 +01001710std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001711 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1712 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001713{
1714 return std::unique_ptr<IWorkload>();
1715}
1716
Derek Lamberti901ea112019-12-10 22:07:09 +00001717std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1718 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001719{
1720 return std::unique_ptr<IWorkload>();
1721}
1722
James Conroyaba90cd2020-11-06 16:28:18 +00001723std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1724 const WorkloadInfo& /*info*/) const
1725{
1726 return std::unique_ptr<IWorkload>();
1727}
1728
1729std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1730 const WorkloadInfo& /*info*/) const
1731{
1732 return std::unique_ptr<IWorkload>();
1733}
1734
Derek Lamberti901ea112019-12-10 22:07:09 +00001735std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1736 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001737{
1738 return std::unique_ptr<IWorkload>();
1739}
1740
Derek Lamberti901ea112019-12-10 22:07:09 +00001741std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1742 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001743{
1744 return std::unique_ptr<IWorkload>();
1745}
1746
Derek Lamberti901ea112019-12-10 22:07:09 +00001747std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1748 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001749{
1750 return std::unique_ptr<IWorkload>();
1751}
1752
Derek Lamberti901ea112019-12-10 22:07:09 +00001753std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1754 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001755{
1756 return std::unique_ptr<IWorkload>();
1757}
1758
Derek Lamberti901ea112019-12-10 22:07:09 +00001759std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1760 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001761{
1762 return std::unique_ptr<IWorkload>();
1763}
1764
Derek Lamberti901ea112019-12-10 22:07:09 +00001765std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1766 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001767{
1768 return std::unique_ptr<IWorkload>();
1769}
1770
Derek Lamberti901ea112019-12-10 22:07:09 +00001771std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1772 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001773{
1774 return std::unique_ptr<IWorkload>();
1775}
1776
Derek Lamberti901ea112019-12-10 22:07:09 +00001777std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1778 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001779{
1780 return std::unique_ptr<IWorkload>();
1781}
1782
Derek Lamberti901ea112019-12-10 22:07:09 +00001783std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1784 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001785{
1786 return std::unique_ptr<IWorkload>();
1787}
1788
Derek Lamberti901ea112019-12-10 22:07:09 +00001789std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1790 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001791{
1792 return std::unique_ptr<IWorkload>();
1793}
1794
Derek Lamberti901ea112019-12-10 22:07:09 +00001795std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1796 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001797{
1798 return std::unique_ptr<IWorkload>();
1799}
1800
Derek Lamberti901ea112019-12-10 22:07:09 +00001801std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1802 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001803{
1804 return std::unique_ptr<IWorkload>();
1805}
1806
Derek Lamberti901ea112019-12-10 22:07:09 +00001807std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001808 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001809{
1810 return std::unique_ptr<IWorkload>();
1811}
1812
Derek Lamberti901ea112019-12-10 22:07:09 +00001813std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1814 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001815{
1816 return std::unique_ptr<IWorkload>();
1817}
1818
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001819std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling3d(const Pooling3dQueueDescriptor& /*descriptor*/,
1820 const WorkloadInfo& /*info*/) const
1821{
1822 return std::unique_ptr<IWorkload>();
1823}
1824
Derek Lamberti901ea112019-12-10 22:07:09 +00001825std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1826 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001827{
1828 return std::unique_ptr<IWorkload>();
1829}
1830
Derek Lamberti901ea112019-12-10 22:07:09 +00001831std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1832 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001833{
1834 return std::unique_ptr<IWorkload>();
1835}
1836
Derek Lamberti901ea112019-12-10 22:07:09 +00001837std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1838 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001839{
1840 return std::unique_ptr<IWorkload>();
1841}
1842
James Conroy586a9aa2020-03-20 08:49:33 +00001843std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1844 const WorkloadInfo& /*info*/) const
1845{
1846 return std::unique_ptr<IWorkload>();
1847}
1848
Derek Lamberti901ea112019-12-10 22:07:09 +00001849std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1850 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001851{
1852 return std::unique_ptr<IWorkload>();
1853}
Finn Williams2605b232020-06-10 15:53:46 +01001854std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1855 const WorkloadInfo& /*info*/) const
1856{
1857 return std::unique_ptr<IWorkload>();
1858}
James Conroyee18dc82019-07-17 11:27:46 +01001859
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001860std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1861 const WorkloadInfo& /*info*/) const
1862{
1863 return std::unique_ptr<IWorkload>();
1864}
1865
Derek Lamberti901ea112019-12-10 22:07:09 +00001866std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1867 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001868{
1869 return std::unique_ptr<IWorkload>();
1870}
1871
Derek Lamberti901ea112019-12-10 22:07:09 +00001872std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1873 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001874{
1875 return std::unique_ptr<IWorkload>();
1876}
1877
Keith Davis3ae3f972021-05-21 16:33:48 +01001878std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
1879 const WorkloadInfo& /*info*/) const
1880{
1881 return std::unique_ptr<IWorkload>();
1882}
1883
Derek Lamberti901ea112019-12-10 22:07:09 +00001884std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1885 const WorkloadInfo& /*info*/) const
1886{
1887 return std::unique_ptr<IWorkload>();
1888}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001889
Derek Lamberti901ea112019-12-10 22:07:09 +00001890std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1891 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001892{
1893 return std::unique_ptr<IWorkload>();
1894}
1895
Derek Lamberti901ea112019-12-10 22:07:09 +00001896std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1897 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001898{
1899 return std::unique_ptr<IWorkload>();
1900}
1901
Derek Lamberti901ea112019-12-10 22:07:09 +00001902std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1903 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001904{
1905 return std::unique_ptr<IWorkload>();
1906}
1907
Derek Lamberti901ea112019-12-10 22:07:09 +00001908std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1909 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001910{
1911 return std::unique_ptr<IWorkload>();
1912}
1913
Derek Lamberti901ea112019-12-10 22:07:09 +00001914std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1915 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001916{
1917 return std::unique_ptr<IWorkload>();
1918}
1919
Derek Lamberti901ea112019-12-10 22:07:09 +00001920std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1921 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001922{
1923 return std::unique_ptr<IWorkload>();
1924}
1925
Derek Lamberti901ea112019-12-10 22:07:09 +00001926std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1927 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001928{
1929 return std::unique_ptr<IWorkload>();
1930}
1931
Derek Lamberti901ea112019-12-10 22:07:09 +00001932std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1933 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001934{
1935 return std::unique_ptr<IWorkload>();
1936}
1937
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001938std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1939 const WorkloadInfo& /*info*/) const
1940{
1941 return std::unique_ptr<IWorkload>();
1942}
1943
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001944std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001945 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1946 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001947{
1948 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001949}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001950
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001951std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
1952 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
1953 const WorkloadInfo& /*info*/) const
1954{
1955 return std::unique_ptr<IWorkload>();
1956}
1957
} // namespace armnn