blob: 3b7f3a0f1f4483a15c26ce418c1c40548738735a [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017#include <backendsCommon/WorkloadFactory.hpp>
James Conroy1f58f032021-04-27 17:13:27 +010018#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
David Beck111b5d92018-11-12 14:59:37 +000022#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
telsoa01c577f2c2018-08-31 09:22:23 +010027namespace
28{
Finn Williams3e54d032020-10-22 16:53:35 +010029using LayerList = std::list<Layer*>;
30using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010031
David Beck29c75de2018-10-23 13:35:58 +010032const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
33{
34 if (!type)
35 {
36 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010037 }
38
Matthew Sloyan81beae32021-07-13 19:46:11 +010039 return TensorInfo(info.GetShape(),
40 type.value(),
41 info.GetQuantizationScale(),
42 info.GetQuantizationOffset(),
43 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010044}
45
David Beck29c75de2018-10-23 13:35:58 +010046} // anonymous namespace
47
Sadik Armagan045f6be2020-09-10 13:37:32 +010048bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
49 const IConnectableLayer& connectableLayer,
50 Optional<DataType> dataType,
51 std::string& outReasonIfUnsupported,
52 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000053{
David Beck33f0ae02018-10-18 15:13:56 +010054 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000055 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010056 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010057
David Beck111b5d92018-11-12 14:59:37 +000058 auto const& backendRegistry = BackendRegistryInstance();
59 if (!backendRegistry.IsBackendRegistered(backendId))
60 {
61 std::stringstream ss;
62 ss << connectableLayer.GetName() << " is not supported on " << backendId
63 << " because this backend is not registered.";
64
65 outReasonIfUnsupported = ss.str();
66 return false;
67 }
68
69 auto backendFactory = backendRegistry.GetFactory(backendId);
70 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000071 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010072
telsoa014fcda012018-03-09 14:13:49 +000073 switch(layer.GetType())
74 {
75 case LayerType::Activation:
76 {
Jan Eilersbb446e52020-04-02 13:56:54 +010077 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000078 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000080 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010081 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010084 reason);
telsoa014fcda012018-03-09 14:13:49 +000085 break;
86 }
87 case LayerType::Addition:
88 {
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000092 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010093 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010096 reason);
telsoa014fcda012018-03-09 14:13:49 +000097 break;
98 }
Nikhil Rajee391d52019-09-05 17:50:44 +010099 case LayerType::ArgMinMax:
100 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100101 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100102 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
103
104 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000106 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100107 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000108 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100109 descriptor,
110 reason);
111 break;
112 }
telsoa014fcda012018-03-09 14:13:49 +0000113 case LayerType::BatchNormalization:
114 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100115 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000122 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100137 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000138
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000139 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000143 break;
144 }
mathad01b392e982021-04-07 12:07:30 +0100145 case LayerType::Cast:
146 {
147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
149
150 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
151 OverrideDataType(output, dataType),
152 reason);
153 break;
154 }
Simon Obute51f67772021-09-03 15:50:13 +0100155 case LayerType::ChannelShuffle:
156 {
157 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
158
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161
162 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
163
164 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
166 descriptor,
167 reason);
168 break;
169 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100170 case LayerType::Comparison:
171 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100172 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100173
174 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
176 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
177
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000178 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
179 OverrideDataType(input1, dataType),
180 OverrideDataType(output, DataType::Boolean),
181 cLayer->GetParameters(),
182 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100183 break;
184 }
telsoa014fcda012018-03-09 14:13:49 +0000185 case LayerType::Constant:
186 {
187 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000188 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 break;
190 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000191 case LayerType::ConvertBf16ToFp32:
192 {
193 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
194 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000195 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000196 break;
197 }
telsoa01c577f2c2018-08-31 09:22:23 +0100198 case LayerType::ConvertFp16ToFp32:
199 {
200 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
201 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000202 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100203 break;
204 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000205 case LayerType::ConvertFp32ToBf16:
206 {
207 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
208 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000209 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000210 break;
211 }
telsoa01c577f2c2018-08-31 09:22:23 +0100212 case LayerType::ConvertFp32ToFp16:
213 {
214 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
215 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000216 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000217 break;
218 }
219 case LayerType::Convolution2d:
220 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100221 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100222
223 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
224 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100225 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100226 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100227
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100228 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100229
arovir01a6824102018-08-28 17:40:45 +0100230 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100231 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100232 if (descriptor.m_BiasEnabled)
233 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100234 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100235 }
236
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000237 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100238 input,
239 output,
240 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100241 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100242 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100243 reason);
telsoa014fcda012018-03-09 14:13:49 +0000244 break;
245 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100246 case LayerType::Convolution3d:
247 {
248 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
249
250 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
251 dataType);
252 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
253 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
254
255 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
256
257 // Construct optional biases object based on the value of m_BiasEnabled
258 Optional<TensorInfo> biases;
259 if (descriptor.m_BiasEnabled)
260 {
261 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
262 }
263
264 result = layerSupportObject.IsConvolution3dSupported(
265 input,
266 output,
267 descriptor,
268 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
269 biases,
270 reason);
271 break;
272 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000273 case LayerType::Debug:
274 {
275 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
276 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
277
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000278 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000279 OverrideDataType(output, dataType),
280 reason);
281 break;
282 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100283 case LayerType::DepthToSpace:
284 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100285 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100286
287 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
288 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
289
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000290 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100291 OverrideDataType(output, dataType),
292 cLayer->GetParameters(),
293 reason);
294 break;
295 }
telsoa014fcda012018-03-09 14:13:49 +0000296 case LayerType::DepthwiseConvolution2d:
297 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100298 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100299 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
300 dataType);
301 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100302 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100303
telsoa01c577f2c2018-08-31 09:22:23 +0100304 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100305
306 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100307 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100308 if (descriptor.m_BiasEnabled)
309 {
David Beck5eec11d2018-10-04 15:43:17 +0100310 biases =
311 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100312 }
telsoa01c577f2c2018-08-31 09:22:23 +0100313
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000314 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100315 input,
316 output,
317 descriptor,
318 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100319 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100320 reason);
telsoa014fcda012018-03-09 14:13:49 +0000321 break;
322 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000323 case LayerType::Dequantize:
324 {
325 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
326 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
327
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000328 result = layerSupportObject.IsDequantizeSupported(input,
329 OverrideDataType(output, dataType),
330 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000331 break;
332 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000333 case LayerType::DetectionPostProcess:
334 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100335 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000336 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
337 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
338 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
339
340 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
341 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
342 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
343 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
344
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000345 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000346 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
347 scores,
348 anchors,
349 detectionBoxes,
350 detectionClasses,
351 detectionScores,
352 numDetections,
353 descriptor,
354 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000355 break;
356 }
josh minor4a3c6102020-01-06 16:40:46 -0600357 case LayerType::ElementwiseUnary:
358 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100359 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600360
361 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
362 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
363
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000364 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
365 OverrideDataType(output, dataType),
366 cLayer->GetParameters(),
367 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600368 break;
369 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100370 case LayerType::Fill:
371 {
372 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
373 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
374 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
375 const FillDescriptor& descriptor = cLayer->GetParameters();
376
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000377 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100378 OverrideDataType(input, dataType),
379 OverrideDataType(output, dataType),
380 descriptor,
381 reason);
382 break;
383 }
telsoa014fcda012018-03-09 14:13:49 +0000384 case LayerType::FakeQuantization:
385 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100386 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000387 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000388 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
389 cLayer->GetParameters(),
390 reason);
telsoa014fcda012018-03-09 14:13:49 +0000391 break;
392 }
393 case LayerType::Floor:
394 {
395 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
396 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000397 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
398 OverrideDataType(output, dataType),
399 reason);
telsoa014fcda012018-03-09 14:13:49 +0000400 break;
401 }
402 case LayerType::FullyConnected:
403 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100404 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000405 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100406 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000407
408 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
409 TensorInfo weightsInfo;
410 const TensorInfo* weightsInfoPtr = nullptr;
411
Matthew Sloyan81beae32021-07-13 19:46:11 +0100412 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000413 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100414
415 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000416 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000417 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100418 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
419 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
420 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
421
telsoa01c577f2c2018-08-31 09:22:23 +0100422 if (descriptor.m_BiasEnabled)
423 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100424 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
425 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100426 }
427 else
428 {
429 // If biases are not enabled pass a dummy tensorinfo for the validation
430 switch(input.GetDataType())
431 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000432 case DataType::BFloat16:
433 {
434 biasInfoPtr = &dummyBFloat16Bias;
435 break;
436 }
telsoa01c577f2c2018-08-31 09:22:23 +0100437 case DataType::Float16:
438 {
439 biasInfoPtr = &dummyFloat16Bias;
440 break;
441 }
442 case DataType::Float32:
443 {
444 biasInfoPtr = &dummyFloat32Bias;
445 break;
446 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000447 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000448 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000449 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000450 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100451 {
452 biasInfoPtr = &dummyQA8Bias;
453 break;
454 }
455 default:
456 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100457 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100458 }
459 }
460 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000461 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100462 OverrideDataType(input, dataType),
463 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000464 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100465 *biasInfoPtr,
466 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100467 reason);
telsoa014fcda012018-03-09 14:13:49 +0000468 break;
469 }
narpra01b89b05f2019-01-16 09:53:09 +0000470 case LayerType::Gather:
471 {
472 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
473 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
474 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100475 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
476 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000477 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
478 input1,
479 OverrideDataType(output, dataType),
480 descriptor,
481 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000482 break;
483 }
telsoa014fcda012018-03-09 14:13:49 +0000484 case LayerType::Input:
485 {
486 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000487 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000488 break;
489 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100490 case LayerType::InstanceNormalization:
491 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100492 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100493 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
494
495 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
496 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
497
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000498 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100499 OverrideDataType(input, dataType),
500 OverrideDataType(output, dataType),
501 descriptor,
502 reason);
503 break;
504 }
telsoa014fcda012018-03-09 14:13:49 +0000505 case LayerType::L2Normalization:
506 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100507 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100508 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
509
telsoa014fcda012018-03-09 14:13:49 +0000510 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100511 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100512
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000513 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100514 OverrideDataType(input, dataType),
515 OverrideDataType(output, dataType),
516 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100517 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100518 break;
519 }
James Conroyaba90cd2020-11-06 16:28:18 +0000520 case LayerType::LogicalBinary:
521 {
522 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
523
524 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
525 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
526 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
527
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000528 result = layerSupportObject.IsLogicalBinarySupported(input0,
529 input1,
530 output,
531 cLayer->GetParameters(),
532 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000533 break;
534 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100535 case LayerType::LogSoftmax:
536 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100537 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100538
539 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
540 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
541
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000542 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
543 OverrideDataType(output, dataType),
544 cLayer->GetParameters(),
545 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100546 break;
547 }
telsoa01c577f2c2018-08-31 09:22:23 +0100548 case LayerType::Lstm:
549 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100550 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100551 const LstmDescriptor& descriptor = cLayer->GetParameters();
552
553 // All inputs.
554 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
555 dataType);
556 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
557 dataType);
558 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
559 dataType);
560 // All outputs
561 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
562 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
563 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
564 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
565
566 // Basic parameters
567 const TensorInfo& inputToForgetWeights
568 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
569 const TensorInfo& inputToCellWeights
570 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
571 const TensorInfo& inputToOutputWeights
572 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
573 const TensorInfo& recurrentToForgetWeights
574 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
575 const TensorInfo& recurrentToCellWeights
576 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
577 const TensorInfo& recurrentToOutputWeights
578 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
579 const TensorInfo& forgetGateBias
580 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
581 const TensorInfo& cellBias
582 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
583 const TensorInfo& outputGateBias
584 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
585
Jan Eilersd01a83c2019-07-03 18:20:40 +0100586 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100587
Jan Eilersd01a83c2019-07-03 18:20:40 +0100588 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
589 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
590 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
591 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
592 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
593 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
594 paramsInfo.m_ForgetGateBias = &forgetGateBias;
595 paramsInfo.m_CellBias = &cellBias;
596 paramsInfo.m_OutputGateBias = &outputGateBias;
597
598
599 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100600 TensorInfo optInputToInputWeights;
601 TensorInfo optRecurrentToInputWeights;
602 TensorInfo optCellToInputWeights;
603 TensorInfo optInputGateBias;
604 TensorInfo optProjectionWeights;
605 TensorInfo optProjectionBias;
606 TensorInfo optCellToForgetWeights;
607 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100608 TensorInfo optInputLayerNormWeights;
609 TensorInfo optForgetLayerNormWeights;
610 TensorInfo optCellLayerNormWeights;
611 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100612
613 if(!descriptor.m_CifgEnabled)
614 {
615 optInputToInputWeights =
616 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100617 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100618
619 optRecurrentToInputWeights =
620 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100621 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100622 optInputGateBias =
623 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100624 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100625 }
626
627 if(descriptor.m_ProjectionEnabled)
628 {
629 optProjectionWeights =
630 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100631 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100632 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
633 {
634 optProjectionBias =
635 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100636 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100637 }
638 }
639
640 if(descriptor.m_PeepholeEnabled)
641 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100642 if(!descriptor.m_CifgEnabled)
643 {
644 optCellToInputWeights =
645 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
646 dataType);
647 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
648 }
telsoa01c577f2c2018-08-31 09:22:23 +0100649 optCellToForgetWeights =
650 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100651 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100652 optCellToOutputWeights =
653 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100654 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100655 }
656
Jan Eilers38e05bd2019-06-26 13:10:09 +0100657 if(descriptor.m_LayerNormEnabled)
658 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100659 if (!descriptor.m_CifgEnabled)
660 {
661 optInputLayerNormWeights = OverrideDataType(
662 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
663 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
664 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100665
666 optForgetLayerNormWeights = OverrideDataType(
667 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100668 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100669
670 optCellLayerNormWeights = OverrideDataType(
671 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100672 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100673
674 optOutputLayerNormWeights = OverrideDataType(
675 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100676 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100677 }
678
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000679 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100680 input,
681 outputStateIn,
682 cellStateIn,
683 scratchBuffer,
684 outputStateOut,
685 cellStateOut,
686 output,
687 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100688 paramsInfo,
689 reason);
telsoa014fcda012018-03-09 14:13:49 +0000690 break;
691 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000692 case LayerType::Maximum:
693 {
694 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
695 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
696 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
697
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000698 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
699 OverrideDataType(input1, dataType),
700 OverrideDataType(output, dataType),
701 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000702 break;
703 }
narpra01b89b05f2019-01-16 09:53:09 +0000704 case LayerType::MemCopy:
705 {
706 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
707 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000708
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000709 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
710 OverrideDataType(output, dataType),
711 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000712 break;
713 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100714 case LayerType::MemImport:
715 {
716 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
717 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
718
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000719 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
720 OverrideDataType(output, dataType),
721 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100722 break;
723 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100724 case LayerType::Merge:
725 {
726 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
727 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
728 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
729
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000730 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
731 OverrideDataType(input1, dataType),
732 OverrideDataType(output, dataType),
733 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100734 break;
735 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100736 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000737 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100738 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000739
telsoa01c577f2c2018-08-31 09:22:23 +0100740 // Get vector of all inputs.
741 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000742 {
telsoa01c577f2c2018-08-31 09:22:23 +0100743 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000744 };
Finn Williams3e54d032020-10-22 16:53:35 +0100745
746 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
747 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100748 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000749
telsoa01c577f2c2018-08-31 09:22:23 +0100750 auto getTensorInfoPtr = [](const TensorInfo& info)
751 {
752 return &info;
753 };
Finn Williams3e54d032020-10-22 16:53:35 +0100754
755 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
756 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100757 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000758
Nikhil Raj8599a412018-11-19 14:51:07 +0000759 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
760
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000761 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100762
763
telsoa014fcda012018-03-09 14:13:49 +0000764 break;
765 }
766 case LayerType::Multiplication:
767 {
768 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
769 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100770 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000771 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100772 OverrideDataType(input0, dataType),
773 OverrideDataType(input1, dataType),
774 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100775 reason);
telsoa014fcda012018-03-09 14:13:49 +0000776 break;
777 }
778 case LayerType::Normalization:
779 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100780 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000781 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
782 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000783 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
784 OverrideDataType(output, dataType),
785 cLayer->GetParameters(),
786 reason);
telsoa014fcda012018-03-09 14:13:49 +0000787 break;
788 }
789 case LayerType::Output:
790 {
791 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000792 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000793 break;
794 }
795 case LayerType::Permute:
796 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100797 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000798 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
799 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000800 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
801 OverrideDataType(output, dataType),
802 cLayer->GetParameters(),
803 reason);
telsoa014fcda012018-03-09 14:13:49 +0000804 break;
805 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100806 case LayerType::Pad:
807 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100808 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100809 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
810 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000811 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100812 OverrideDataType(input, dataType),
813 OverrideDataType(output, dataType),
814 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100815 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100816 break;
817 }
telsoa014fcda012018-03-09 14:13:49 +0000818 case LayerType::Pooling2d:
819 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100820 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000821 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
822 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000823 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
824 OverrideDataType(output, dataType),
825 cLayer->GetParameters(),
826 reason);
telsoa014fcda012018-03-09 14:13:49 +0000827 break;
828 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000829 case LayerType::PreCompiled:
830 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100831 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000832 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000833 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
834 cLayer->GetParameters(),
835 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000836 break;
837 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000838 case LayerType::Quantize:
839 {
840 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
841 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000842 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000843 break;
844 }
James Conroy586a9aa2020-03-20 08:49:33 +0000845 case LayerType::QLstm:
846 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100847 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000848 const QLstmDescriptor& descriptor = cLayer->GetParameters();
849
850 // Inputs
851 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
852 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
853 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
854
855 // Outputs
856 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
857 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
858 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
859
860 // Lstm parameters
861 LstmInputParamsInfo paramsInfo;
862
863 // Basic parameters
Matthew Bentham6f24b1a2021-06-29 15:18:32 +0100864 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
865 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
866 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
James Conroy586a9aa2020-03-20 08:49:33 +0000867 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
868 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
869 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
870
871 paramsInfo.m_RecurrentToForgetWeights =
872 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
873 paramsInfo.m_RecurrentToCellWeights =
874 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
875 paramsInfo.m_RecurrentToOutputWeights =
876 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
877
878 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
879 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
880 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
881
882 if(!descriptor.m_CifgEnabled)
883 {
884 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
885 paramsInfo.m_RecurrentToInputWeights =
886 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
887 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
888 }
889
890 if(descriptor.m_ProjectionEnabled)
891 {
892 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100893
894 // Projection bias is optional even if projection is enabled
895 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
896 {
897 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
898 }
James Conroy586a9aa2020-03-20 08:49:33 +0000899 }
900
901 if(descriptor.m_PeepholeEnabled)
902 {
903 if (!descriptor.m_CifgEnabled)
904 {
905 paramsInfo.m_CellToInputWeights =
906 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
907 }
908
909 paramsInfo.m_CellToForgetWeights =
910 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
911 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
912 }
913
914 if(descriptor.m_LayerNormEnabled)
915 {
916 if (!descriptor.m_CifgEnabled)
917 {
918 paramsInfo.m_InputLayerNormWeights =
919 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
920 }
921
922 paramsInfo.m_ForgetLayerNormWeights =
923 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
924 paramsInfo.m_CellLayerNormWeights =
925 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
926 paramsInfo.m_OutputLayerNormWeights =
927 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
928 }
929
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000930 result = layerSupportObject.IsQLstmSupported(input,
931 previousOutputIn,
932 previousCellStateIn,
933 outputStateOut,
934 cellStateOut,
935 output,
936 descriptor,
937 paramsInfo,
938 reason);
James Conroy586a9aa2020-03-20 08:49:33 +0000939 break;
940 }
James Conroyee18dc82019-07-17 11:27:46 +0100941 case LayerType::QuantizedLstm:
942 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100943 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100944
945 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100946 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
947 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
948 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100949
950 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100951 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
952 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100953
954 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100955 QuantizedLstmInputParamsInfo paramsInfo;
956
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100957 paramsInfo.m_InputToInputWeights =
958 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
959 paramsInfo.m_InputToForgetWeights =
960 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
961 paramsInfo.m_InputToCellWeights =
962 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
963 paramsInfo.m_InputToOutputWeights =
964 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100965
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100966 paramsInfo.m_RecurrentToInputWeights =
967 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
968 paramsInfo.m_RecurrentToForgetWeights =
969 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
970 paramsInfo.m_RecurrentToCellWeights =
971 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
972 paramsInfo.m_RecurrentToOutputWeights =
973 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100974
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100975 paramsInfo.m_InputGateBias =
976 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
977 paramsInfo.m_ForgetGateBias =
978 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
979 paramsInfo.m_CellBias =
980 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
981 paramsInfo.m_OutputGateBias =
982 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100983
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000984 result = layerSupportObject.IsQuantizedLstmSupported(input,
985 previousCellStateIn,
986 previousOutputIn,
987 cellStateOut,
988 output,
989 paramsInfo,
990 reason);
James Conroyee18dc82019-07-17 11:27:46 +0100991 break;
992 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100993 case LayerType::Division:
994 {
995 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
996 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
997 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000998 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100999 OverrideDataType(input0, dataType),
1000 OverrideDataType(input1, dataType),
1001 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001002 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001003 break;
1004 }
Finn Williams2605b232020-06-10 15:53:46 +01001005 case LayerType::Rank:
1006 {
1007 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1008 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001009 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1010 OverrideDataType(output, dataType),
1011 reason);
Finn Williams2605b232020-06-10 15:53:46 +01001012 break;
1013 }
telsoa014fcda012018-03-09 14:13:49 +00001014 case LayerType::Reshape:
1015 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001016 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001017 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +00001018 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001019 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1020 OverrideDataType(output, dataType),
1021 cLayer->GetParameters(),
1022 reason);
telsoa014fcda012018-03-09 14:13:49 +00001023 break;
1024 }
Teresa Charlina9075df2019-06-27 15:41:57 +01001025 case LayerType::Resize:
1026 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001027 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +01001028 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001029 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001030 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1031 OverrideDataType(output, dataType),
1032 cLayer->GetParameters(),
1033 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001034 break;
1035 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001036 case LayerType::Shape:
1037 {
1038 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1039 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1040
1041 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1042 OverrideDataType(output, dataType),
1043 reason);
1044 break;
1045 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001046 case LayerType::Slice:
1047 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001048 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001049
1050 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1051 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1052
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001053 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1054 OverrideDataType(output, dataType),
1055 cLayer->GetParameters(),
1056 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001057 break;
1058 }
telsoa014fcda012018-03-09 14:13:49 +00001059 case LayerType::Softmax:
1060 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001061 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001062 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001063 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001064 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1065 OverrideDataType(output, dataType),
1066 cLayer->GetParameters(),
1067 reason);
telsoa014fcda012018-03-09 14:13:49 +00001068 break;
1069 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001070 case LayerType::SpaceToBatchNd:
1071 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001072 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001073 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1074 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001075 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1076 OverrideDataType(output, dataType),
1077 cLayer->GetParameters(),
1078 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001079 break;
1080 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001081 case LayerType::SpaceToDepth:
1082 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001083 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001084
1085 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1086 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1087
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001088 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1089 OverrideDataType(output, dataType),
1090 cLayer->GetParameters(),
1091 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001092 break;
1093 }
telsoa014fcda012018-03-09 14:13:49 +00001094 case LayerType::Splitter:
1095 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001096 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001097 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001098
1099 // Get vector of all outputs.
1100 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1101 {
1102 return OverrideDataType(slot.GetTensorInfo(), dataType);
1103 };
Finn Williams3e54d032020-10-22 16:53:35 +01001104 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1105 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001106 std::vector<TensorInfo> outputs(beginI, endI);
1107
1108 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1109
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001110 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1111 outputPtrs,
1112 cLayer->GetParameters(),
1113 reason);
telsoa014fcda012018-03-09 14:13:49 +00001114 break;
1115 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001116 case LayerType::Stack:
1117 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001118 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001119
1120 // Get vector of all inputs.
1121 auto getTensorInfo = [&dataType](const InputSlot& slot)
1122 {
1123 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1124 };
Finn Williams3e54d032020-10-22 16:53:35 +01001125 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1126 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001127 std::vector<TensorInfo> inputs(beginI, endI);
1128
1129 auto getTensorInfoPtr = [](const TensorInfo& info)
1130 {
1131 return &info;
1132 };
Finn Williams3e54d032020-10-22 16:53:35 +01001133 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1134 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001135 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1136
1137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1138
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001139 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001140
1141 break;
1142 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001143 case LayerType::StandIn:
1144 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001145 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001146
1147 // Get vector of all inputs.
1148 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1149 {
1150 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1151 };
1152 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1153 {
1154 return OverrideDataType(slot.GetTensorInfo(), dataType);
1155 };
Finn Williams3e54d032020-10-22 16:53:35 +01001156 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1157 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001158 std::vector<TensorInfo> inputs(beginI, endI);
1159
Finn Williams3e54d032020-10-22 16:53:35 +01001160 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1161 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001162 std::vector<TensorInfo> outputs(beginO, endO);
1163
1164
1165 auto getTensorInfoPtr = [](const TensorInfo& info)
1166 {
1167 return &info;
1168 };
Finn Williams3e54d032020-10-22 16:53:35 +01001169 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1170 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001171 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1172
Finn Williams3e54d032020-10-22 16:53:35 +01001173 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1174 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001175 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1176
1177
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001178 result = layerSupportObject.IsStandInSupported(inputPtrs,
1179 outputPtrs,
1180 cLayer->GetParameters(),
1181 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001182 break;
1183 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001184 case LayerType::StridedSlice:
1185 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001186 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001187 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1188 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001189 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1190 OverrideDataType(output, dataType),
1191 cLayer->GetParameters(),
1192 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001193 break;
1194 }
David Beckc2044fe2018-09-05 15:00:38 +01001195 case LayerType::Subtraction:
1196 {
1197 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1198 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1199 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001200 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001201 OverrideDataType(input0, dataType),
1202 OverrideDataType(input1, dataType),
1203 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001204 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001205 break;
1206 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001207 case LayerType::Switch:
1208 {
1209 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1210 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1211 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1212 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001213 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1214 OverrideDataType(input1, dataType),
1215 OverrideDataType(output0, dataType),
1216 OverrideDataType(output1, dataType),
1217 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001218 break;
1219 }
narpra0132b90462018-09-13 11:07:48 +01001220 case LayerType::Mean:
1221 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001222 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001223 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1224 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001225 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001226 OverrideDataType(input, dataType),
1227 OverrideDataType(output, dataType),
1228 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001229 reason);
narpra0132b90462018-09-13 11:07:48 +01001230 break;
1231 }
kevmay0190539692018-11-29 08:40:19 +00001232 case LayerType::Minimum:
1233 {
1234 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1235 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1236 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001237 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1238 OverrideDataType(input1, dataType),
1239 OverrideDataType(output, dataType),
1240 reason);
kevmay0190539692018-11-29 08:40:19 +00001241 break;
1242 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001243 case LayerType::Prelu:
1244 {
1245 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1246 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1247 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001248 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1249 OverrideDataType(alpha, dataType),
1250 OverrideDataType(output, dataType),
1251 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001252 break;
1253 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001254 case LayerType::Transpose:
1255 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001256 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001259 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1260 OverrideDataType(output, dataType),
1261 cLayer->GetParameters(),
1262 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001263 break;
1264 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001265 case LayerType::TransposeConvolution2d:
1266 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001267 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001268
1269 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1270 dataType);
1271 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1272
1273 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1274
1275 Optional<TensorInfo> biases;
1276 if (descriptor.m_BiasEnabled)
1277 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001278 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001279 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1280 GetBiasTypeFromWeightsType(dataType));
1281 }
1282
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001283 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001284 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1285
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001286 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1287 output,
1288 descriptor,
1289 weights,
1290 biases,
1291 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001292
1293 break;
1294 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001295 case LayerType::Reduce:
1296 {
1297 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1298 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1299 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1300
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001301 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1302 OverrideDataType(output, dataType),
1303 cLayer->GetParameters(),
1304 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001305 break;
1306 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001307 case LayerType::UnidirectionalSequenceLstm:
1308 {
1309 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1310 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1311
1312 // All inputs.
1313 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1314 dataType);
1315 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1316 dataType);
1317 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1318 dataType);
1319 // Outputs
1320 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1321
1322 // Basic parameters
1323 const TensorInfo& inputToForgetWeights
1324 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1325 const TensorInfo& inputToCellWeights
1326 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1327 const TensorInfo& inputToOutputWeights
1328 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1329 const TensorInfo& recurrentToForgetWeights
1330 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1331 const TensorInfo& recurrentToCellWeights
1332 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1333 const TensorInfo& recurrentToOutputWeights
1334 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1335 const TensorInfo& forgetGateBias
1336 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1337 const TensorInfo& cellBias
1338 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1339 const TensorInfo& outputGateBias
1340 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1341
1342 LstmInputParamsInfo paramsInfo;
1343
1344 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1345 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1346 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1347 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1348 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1349 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1350 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1351 paramsInfo.m_CellBias = &cellBias;
1352 paramsInfo.m_OutputGateBias = &outputGateBias;
1353
1354 // Optional parameters
1355 TensorInfo optInputToInputWeights;
1356 TensorInfo optRecurrentToInputWeights;
1357 TensorInfo optCellToInputWeights;
1358 TensorInfo optInputGateBias;
1359 TensorInfo optProjectionWeights;
1360 TensorInfo optProjectionBias;
1361 TensorInfo optCellToForgetWeights;
1362 TensorInfo optCellToOutputWeights;
1363 TensorInfo optInputLayerNormWeights;
1364 TensorInfo optForgetLayerNormWeights;
1365 TensorInfo optCellLayerNormWeights;
1366 TensorInfo optOutputLayerNormWeights;
1367
1368 if(!descriptor.m_CifgEnabled)
1369 {
1370 optInputToInputWeights =
1371 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1372 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1373
1374 optRecurrentToInputWeights =
1375 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1376 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1377 optInputGateBias =
1378 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1379 paramsInfo.m_InputGateBias = &optInputGateBias;
1380 }
1381
1382 if(descriptor.m_ProjectionEnabled)
1383 {
1384 optProjectionWeights =
1385 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1386 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1387 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1388 {
1389 optProjectionBias =
1390 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1391 paramsInfo.m_ProjectionBias = &optProjectionBias;
1392 }
1393 }
1394
1395 if(descriptor.m_PeepholeEnabled)
1396 {
1397 if(!descriptor.m_CifgEnabled)
1398 {
1399 optCellToInputWeights =
1400 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1401 dataType);
1402 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1403 }
1404 optCellToForgetWeights =
1405 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1406 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1407 optCellToOutputWeights =
1408 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1409 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1410 }
1411
1412 if(descriptor.m_LayerNormEnabled)
1413 {
1414 if (!descriptor.m_CifgEnabled)
1415 {
1416 optInputLayerNormWeights = OverrideDataType(
1417 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1418 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1419 }
1420
1421 optForgetLayerNormWeights = OverrideDataType(
1422 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1423 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1424
1425 optCellLayerNormWeights = OverrideDataType(
1426 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1427 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1428
1429 optOutputLayerNormWeights = OverrideDataType(
1430 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1431 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1432 }
1433
1434 Optional<TensorInfo> hiddenStateOut;
1435 Optional<TensorInfo> cellStateOut;
1436
1437 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1438 outputStateIn,
1439 cellStateIn,
1440 output,
1441 hiddenStateOut,
1442 cellStateOut,
1443 descriptor,
1444 paramsInfo,
1445 reason);
1446 break;
1447 }
telsoa014fcda012018-03-09 14:13:49 +00001448 default:
1449 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001450 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001451 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001452 result = false;
1453 break;
1454 }
1455 }
telsoa014fcda012018-03-09 14:13:49 +00001456 return result;
1457}
1458
Sadik Armagan045f6be2020-09-10 13:37:32 +01001459bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1460 const IConnectableLayer& connectableLayer,
1461 Optional<DataType> dataType,
1462 std::string& outReasonIfUnsupported)
1463{
1464 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1465}
1466
David Beckdcb751f2018-10-03 11:42:42 +01001467bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001468 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001469 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001470{
Jan Eilersbb446e52020-04-02 13:56:54 +01001471 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001472 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1473}
1474
1475// TODO merge with defaulted modelOptions above
1476bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1477 Optional<DataType> dataType,
1478 std::string& outReasonIfUnsupported,
1479 const ModelOptions& modelOptions)
1480{
1481 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1482 return IsLayerConfigurationSupported(layer->GetBackendId(),
1483 connectableLayer,
1484 dataType,
1485 outReasonIfUnsupported,
1486 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001487}
1488
Sadik Armagan04a72972020-09-14 15:44:18 +01001489bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1490 const IConnectableLayer& connectableLayer,
1491 Optional<DataType> dataType,
1492 std::string& outReasonIfUnsupported,
1493 const ModelOptions& modelOptions)
1494{
1495 return IsLayerConfigurationSupported(backendId,
1496 connectableLayer,
1497 dataType,
1498 outReasonIfUnsupported,
1499 modelOptions);
1500}
1501
Derek Lamberti901ea112019-12-10 22:07:09 +00001502std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1503 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001504{
1505 return std::unique_ptr<IWorkload>();
1506}
1507
Derek Lamberti901ea112019-12-10 22:07:09 +00001508std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1509 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001510{
1511 return std::unique_ptr<IWorkload>();
1512}
1513
Derek Lamberti901ea112019-12-10 22:07:09 +00001514std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1515 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001516{
1517 return std::unique_ptr<IWorkload>();
1518}
1519
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001520std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001521 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001522{
1523 return std::unique_ptr<IWorkload>();
1524}
1525
Derek Lamberti901ea112019-12-10 22:07:09 +00001526std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1527 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001528{
1529 return std::unique_ptr<IWorkload>();
1530}
1531
mathad01b392e982021-04-07 12:07:30 +01001532std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1533 const WorkloadInfo& /*info*/) const
1534{
1535 return std::unique_ptr<IWorkload>();
1536}
1537
Simon Obute51f67772021-09-03 15:50:13 +01001538std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1539 const WorkloadInfo& /*info*/) const
1540{
1541 return std::unique_ptr<IWorkload>();
1542}
1543
Derek Lamberti901ea112019-12-10 22:07:09 +00001544std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1545 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001546{
1547 return std::unique_ptr<IWorkload>();
1548}
1549
Derek Lamberti901ea112019-12-10 22:07:09 +00001550std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1551 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001552{
1553 return std::unique_ptr<IWorkload>();
1554}
1555
Derek Lamberti901ea112019-12-10 22:07:09 +00001556std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1557 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001558{
1559 return std::unique_ptr<IWorkload>();
1560}
1561
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001562std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1563 const WorkloadInfo& /*info*/) const
1564{
1565 return std::unique_ptr<IWorkload>();
1566}
1567
Derek Lamberti901ea112019-12-10 22:07:09 +00001568std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1569 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001570{
1571 return std::unique_ptr<IWorkload>();
1572}
1573
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001574std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1575 const WorkloadInfo& /*info*/) const
1576{
1577 return std::unique_ptr<IWorkload>();
1578}
1579
Derek Lamberti901ea112019-12-10 22:07:09 +00001580std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1581 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001582{
1583 return std::unique_ptr<IWorkload>();
1584}
1585
Derek Lamberti901ea112019-12-10 22:07:09 +00001586std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1587 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001588{
1589 return std::unique_ptr<IWorkload>();
1590}
1591
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001592std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/,
1593 const WorkloadInfo& /*info*/) const
1594{
1595 return std::unique_ptr<IWorkload>();
1596}
1597
Derek Lamberti901ea112019-12-10 22:07:09 +00001598std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1599 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001600{
1601 return std::unique_ptr<IWorkload>();
1602}
1603
Derek Lamberti901ea112019-12-10 22:07:09 +00001604std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1605 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001606{
1607 return std::unique_ptr<IWorkload>();
1608}
1609
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001610std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001611 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001612{
1613 return std::unique_ptr<IWorkload>();
1614}
1615
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001616std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001617 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001618{
1619 return std::unique_ptr<IWorkload>();
1620}
1621
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001622std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001623 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001624{
1625 return std::unique_ptr<IWorkload>();
1626}
1627
Derek Lamberti901ea112019-12-10 22:07:09 +00001628std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1629 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001630{
1631 return std::unique_ptr<IWorkload>();
1632}
1633
josh minor4a3c6102020-01-06 16:40:46 -06001634std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1635 const WorkloadInfo& /*info*/) const
1636{
1637 return std::unique_ptr<IWorkload>();
1638}
1639
Derek Lamberti901ea112019-12-10 22:07:09 +00001640std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1641 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001642{
1643 return std::unique_ptr<IWorkload>();
1644}
1645
Ryan OSheaec6c6802020-06-05 17:17:06 +01001646std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1647 const WorkloadInfo& /*info*/) const
1648{
1649 return std::unique_ptr<IWorkload>();
1650}
1651
Derek Lamberti901ea112019-12-10 22:07:09 +00001652std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1653 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001654{
1655 return std::unique_ptr<IWorkload>();
1656}
1657
Derek Lamberti901ea112019-12-10 22:07:09 +00001658std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1659 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001660{
1661 return std::unique_ptr<IWorkload>();
1662}
1663
Derek Lamberti901ea112019-12-10 22:07:09 +00001664std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1665 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001666{
1667 return std::unique_ptr<IWorkload>();
1668}
1669
Kevin Mayce5045a2019-10-02 14:07:47 +01001670std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001671 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1672 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001673{
1674 return std::unique_ptr<IWorkload>();
1675}
1676
Derek Lamberti901ea112019-12-10 22:07:09 +00001677std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1678 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001679{
1680 return std::unique_ptr<IWorkload>();
1681}
1682
James Conroyaba90cd2020-11-06 16:28:18 +00001683std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1684 const WorkloadInfo& /*info*/) const
1685{
1686 return std::unique_ptr<IWorkload>();
1687}
1688
1689std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1690 const WorkloadInfo& /*info*/) const
1691{
1692 return std::unique_ptr<IWorkload>();
1693}
1694
Derek Lamberti901ea112019-12-10 22:07:09 +00001695std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1696 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001697{
1698 return std::unique_ptr<IWorkload>();
1699}
1700
Derek Lamberti901ea112019-12-10 22:07:09 +00001701std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1702 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001703{
1704 return std::unique_ptr<IWorkload>();
1705}
1706
Derek Lamberti901ea112019-12-10 22:07:09 +00001707std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1708 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001709{
1710 return std::unique_ptr<IWorkload>();
1711}
1712
Derek Lamberti901ea112019-12-10 22:07:09 +00001713std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1714 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001715{
1716 return std::unique_ptr<IWorkload>();
1717}
1718
Derek Lamberti901ea112019-12-10 22:07:09 +00001719std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1720 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001721{
1722 return std::unique_ptr<IWorkload>();
1723}
1724
Derek Lamberti901ea112019-12-10 22:07:09 +00001725std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1726 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001727{
1728 return std::unique_ptr<IWorkload>();
1729}
1730
Derek Lamberti901ea112019-12-10 22:07:09 +00001731std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1732 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001733{
1734 return std::unique_ptr<IWorkload>();
1735}
1736
Derek Lamberti901ea112019-12-10 22:07:09 +00001737std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1738 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001739{
1740 return std::unique_ptr<IWorkload>();
1741}
1742
Derek Lamberti901ea112019-12-10 22:07:09 +00001743std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1744 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001745{
1746 return std::unique_ptr<IWorkload>();
1747}
1748
Derek Lamberti901ea112019-12-10 22:07:09 +00001749std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1750 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001751{
1752 return std::unique_ptr<IWorkload>();
1753}
1754
Derek Lamberti901ea112019-12-10 22:07:09 +00001755std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1756 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001757{
1758 return std::unique_ptr<IWorkload>();
1759}
1760
Derek Lamberti901ea112019-12-10 22:07:09 +00001761std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1762 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001763{
1764 return std::unique_ptr<IWorkload>();
1765}
1766
Derek Lamberti901ea112019-12-10 22:07:09 +00001767std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001768 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001769{
1770 return std::unique_ptr<IWorkload>();
1771}
1772
Derek Lamberti901ea112019-12-10 22:07:09 +00001773std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1774 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001775{
1776 return std::unique_ptr<IWorkload>();
1777}
1778
Derek Lamberti901ea112019-12-10 22:07:09 +00001779std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1780 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001781{
1782 return std::unique_ptr<IWorkload>();
1783}
1784
Derek Lamberti901ea112019-12-10 22:07:09 +00001785std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1786 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001787{
1788 return std::unique_ptr<IWorkload>();
1789}
1790
Derek Lamberti901ea112019-12-10 22:07:09 +00001791std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1792 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001793{
1794 return std::unique_ptr<IWorkload>();
1795}
1796
James Conroy586a9aa2020-03-20 08:49:33 +00001797std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1798 const WorkloadInfo& /*info*/) const
1799{
1800 return std::unique_ptr<IWorkload>();
1801}
1802
Derek Lamberti901ea112019-12-10 22:07:09 +00001803std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1804 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001805{
1806 return std::unique_ptr<IWorkload>();
1807}
Finn Williams2605b232020-06-10 15:53:46 +01001808std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1809 const WorkloadInfo& /*info*/) const
1810{
1811 return std::unique_ptr<IWorkload>();
1812}
James Conroyee18dc82019-07-17 11:27:46 +01001813
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001814std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1815 const WorkloadInfo& /*info*/) const
1816{
1817 return std::unique_ptr<IWorkload>();
1818}
1819
Derek Lamberti901ea112019-12-10 22:07:09 +00001820std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1821 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001822{
1823 return std::unique_ptr<IWorkload>();
1824}
1825
Derek Lamberti901ea112019-12-10 22:07:09 +00001826std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1827 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001828{
1829 return std::unique_ptr<IWorkload>();
1830}
1831
Keith Davis3ae3f972021-05-21 16:33:48 +01001832std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
1833 const WorkloadInfo& /*info*/) const
1834{
1835 return std::unique_ptr<IWorkload>();
1836}
1837
Derek Lamberti901ea112019-12-10 22:07:09 +00001838std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1839 const WorkloadInfo& /*info*/) const
1840{
1841 return std::unique_ptr<IWorkload>();
1842}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001843
Derek Lamberti901ea112019-12-10 22:07:09 +00001844std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1845 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001846{
1847 return std::unique_ptr<IWorkload>();
1848}
1849
Derek Lamberti901ea112019-12-10 22:07:09 +00001850std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1851 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001852{
1853 return std::unique_ptr<IWorkload>();
1854}
1855
Derek Lamberti901ea112019-12-10 22:07:09 +00001856std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1857 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001858{
1859 return std::unique_ptr<IWorkload>();
1860}
1861
Derek Lamberti901ea112019-12-10 22:07:09 +00001862std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1863 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001864{
1865 return std::unique_ptr<IWorkload>();
1866}
1867
Derek Lamberti901ea112019-12-10 22:07:09 +00001868std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1869 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001870{
1871 return std::unique_ptr<IWorkload>();
1872}
1873
Derek Lamberti901ea112019-12-10 22:07:09 +00001874std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1875 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001876{
1877 return std::unique_ptr<IWorkload>();
1878}
1879
Derek Lamberti901ea112019-12-10 22:07:09 +00001880std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1881 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001882{
1883 return std::unique_ptr<IWorkload>();
1884}
1885
Derek Lamberti901ea112019-12-10 22:07:09 +00001886std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1887 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001888{
1889 return std::unique_ptr<IWorkload>();
1890}
1891
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001892std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1893 const WorkloadInfo& /*info*/) const
1894{
1895 return std::unique_ptr<IWorkload>();
1896}
1897
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001898std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001899 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1900 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001901{
1902 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001903}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001904
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001905std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
1906 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
1907 const WorkloadInfo& /*info*/) const
1908{
1909 return std::unique_ptr<IWorkload>();
1910}
1911
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001912} // namepsace armnn