blob: 55ce3554f9d212eb19ae0aeeab2c09f26b9527ec [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017#include <backendsCommon/WorkloadFactory.hpp>
James Conroy1f58f032021-04-27 17:13:27 +010018#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
David Beck111b5d92018-11-12 14:59:37 +000022#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
telsoa01c577f2c2018-08-31 09:22:23 +010027namespace
28{
// Aliases for a list of non-owning layer pointers and its read-only iterator,
// shared by the graph-traversal helpers in this file.
using LayerList = std::list<Layer*>;
using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010031
David Beck29c75de2018-10-23 13:35:58 +010032const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
33{
34 if (!type)
35 {
36 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010037 }
38
Matthew Sloyan81beae32021-07-13 19:46:11 +010039 return TensorInfo(info.GetShape(),
40 type.value(),
41 info.GetQuantizationScale(),
42 info.GetQuantizationOffset(),
43 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010044}
45
David Beck29c75de2018-10-23 13:35:58 +010046} // anonymous namespace
47
Sadik Armagan045f6be2020-09-10 13:37:32 +010048bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
49 const IConnectableLayer& connectableLayer,
50 Optional<DataType> dataType,
51 std::string& outReasonIfUnsupported,
52 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000053{
David Beck33f0ae02018-10-18 15:13:56 +010054 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000055 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010056 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010057
David Beck111b5d92018-11-12 14:59:37 +000058 auto const& backendRegistry = BackendRegistryInstance();
59 if (!backendRegistry.IsBackendRegistered(backendId))
60 {
61 std::stringstream ss;
62 ss << connectableLayer.GetName() << " is not supported on " << backendId
63 << " because this backend is not registered.";
64
65 outReasonIfUnsupported = ss.str();
66 return false;
67 }
68
69 auto backendFactory = backendRegistry.GetFactory(backendId);
70 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000071 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010072
telsoa014fcda012018-03-09 14:13:49 +000073 switch(layer.GetType())
74 {
75 case LayerType::Activation:
76 {
Jan Eilersbb446e52020-04-02 13:56:54 +010077 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000078 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000080 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010081 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010084 reason);
telsoa014fcda012018-03-09 14:13:49 +000085 break;
86 }
87 case LayerType::Addition:
88 {
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000092 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010093 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010096 reason);
telsoa014fcda012018-03-09 14:13:49 +000097 break;
98 }
Nikhil Rajee391d52019-09-05 17:50:44 +010099 case LayerType::ArgMinMax:
100 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100101 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100102 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
103
104 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000106 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100107 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000108 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100109 descriptor,
110 reason);
111 break;
112 }
telsoa014fcda012018-03-09 14:13:49 +0000113 case LayerType::BatchNormalization:
114 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100115 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000122 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100137 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000138
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000139 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000143 break;
144 }
mathad01b392e982021-04-07 12:07:30 +0100145 case LayerType::Cast:
146 {
147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
149
150 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
151 OverrideDataType(output, dataType),
152 reason);
153 break;
154 }
Simon Obute51f67772021-09-03 15:50:13 +0100155 case LayerType::ChannelShuffle:
156 {
157 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
158
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161
162 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
163
164 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
166 descriptor,
167 reason);
168 break;
169 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100170 case LayerType::Comparison:
171 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100172 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100173
174 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
176 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
177
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000178 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
179 OverrideDataType(input1, dataType),
180 OverrideDataType(output, DataType::Boolean),
181 cLayer->GetParameters(),
182 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100183 break;
184 }
telsoa014fcda012018-03-09 14:13:49 +0000185 case LayerType::Constant:
186 {
187 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000188 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 break;
190 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000191 case LayerType::ConvertBf16ToFp32:
192 {
193 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
194 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000195 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000196 break;
197 }
telsoa01c577f2c2018-08-31 09:22:23 +0100198 case LayerType::ConvertFp16ToFp32:
199 {
200 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
201 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000202 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100203 break;
204 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000205 case LayerType::ConvertFp32ToBf16:
206 {
207 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
208 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000209 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000210 break;
211 }
telsoa01c577f2c2018-08-31 09:22:23 +0100212 case LayerType::ConvertFp32ToFp16:
213 {
214 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
215 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000216 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000217 break;
218 }
219 case LayerType::Convolution2d:
220 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100221 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100222
223 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
224 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100225 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100226 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100227
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100228 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100229
arovir01a6824102018-08-28 17:40:45 +0100230 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100231 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100232 if (descriptor.m_BiasEnabled)
233 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100234 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100235 }
236
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000237 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100238 input,
239 output,
240 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100241 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100242 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100243 reason);
telsoa014fcda012018-03-09 14:13:49 +0000244 break;
245 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100246 case LayerType::Convolution3d:
247 {
248 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
249
250 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
251 dataType);
252 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100253
254 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
255 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
256 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
257 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100258
259 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
260
261 // Construct optional biases object based on the value of m_BiasEnabled
262 Optional<TensorInfo> biases;
263 if (descriptor.m_BiasEnabled)
264 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100265 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
266 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100267 }
268
269 result = layerSupportObject.IsConvolution3dSupported(
270 input,
271 output,
272 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100273 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100274 biases,
275 reason);
276 break;
277 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000278 case LayerType::Debug:
279 {
280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
282
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000283 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000284 OverrideDataType(output, dataType),
285 reason);
286 break;
287 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100288 case LayerType::DepthToSpace:
289 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100290 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100291
292 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
293 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
294
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000295 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100296 OverrideDataType(output, dataType),
297 cLayer->GetParameters(),
298 reason);
299 break;
300 }
telsoa014fcda012018-03-09 14:13:49 +0000301 case LayerType::DepthwiseConvolution2d:
302 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100303 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100304 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
305 dataType);
306 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100307 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100308
telsoa01c577f2c2018-08-31 09:22:23 +0100309 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100310
311 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100312 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100313 if (descriptor.m_BiasEnabled)
314 {
David Beck5eec11d2018-10-04 15:43:17 +0100315 biases =
316 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100317 }
telsoa01c577f2c2018-08-31 09:22:23 +0100318
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000319 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100320 input,
321 output,
322 descriptor,
323 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100324 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100325 reason);
telsoa014fcda012018-03-09 14:13:49 +0000326 break;
327 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000328 case LayerType::Dequantize:
329 {
330 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
331 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
332
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000333 result = layerSupportObject.IsDequantizeSupported(input,
334 OverrideDataType(output, dataType),
335 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000336 break;
337 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000338 case LayerType::DetectionPostProcess:
339 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100340 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000341 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
342 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
343 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
344
345 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
346 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
347 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
348 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
349
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000350 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000351 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
352 scores,
353 anchors,
354 detectionBoxes,
355 detectionClasses,
356 detectionScores,
357 numDetections,
358 descriptor,
359 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000360 break;
361 }
josh minor4a3c6102020-01-06 16:40:46 -0600362 case LayerType::ElementwiseUnary:
363 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100364 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600365
366 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
367 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
368
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000369 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
370 OverrideDataType(output, dataType),
371 cLayer->GetParameters(),
372 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600373 break;
374 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100375 case LayerType::Fill:
376 {
377 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
378 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
379 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
380 const FillDescriptor& descriptor = cLayer->GetParameters();
381
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000382 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100383 OverrideDataType(input, dataType),
384 OverrideDataType(output, dataType),
385 descriptor,
386 reason);
387 break;
388 }
telsoa014fcda012018-03-09 14:13:49 +0000389 case LayerType::FakeQuantization:
390 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100391 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000392 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000393 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
394 cLayer->GetParameters(),
395 reason);
telsoa014fcda012018-03-09 14:13:49 +0000396 break;
397 }
398 case LayerType::Floor:
399 {
400 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
401 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000402 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
403 OverrideDataType(output, dataType),
404 reason);
telsoa014fcda012018-03-09 14:13:49 +0000405 break;
406 }
407 case LayerType::FullyConnected:
408 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100409 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000410 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100411 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000412
413 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
414 TensorInfo weightsInfo;
415 const TensorInfo* weightsInfoPtr = nullptr;
416
Matthew Sloyan81beae32021-07-13 19:46:11 +0100417 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000418 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100419
420 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000421 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000422 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100423 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
424 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
425 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
426
telsoa01c577f2c2018-08-31 09:22:23 +0100427 if (descriptor.m_BiasEnabled)
428 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100429 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
430 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100431 }
432 else
433 {
434 // If biases are not enabled pass a dummy tensorinfo for the validation
435 switch(input.GetDataType())
436 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000437 case DataType::BFloat16:
438 {
439 biasInfoPtr = &dummyBFloat16Bias;
440 break;
441 }
telsoa01c577f2c2018-08-31 09:22:23 +0100442 case DataType::Float16:
443 {
444 biasInfoPtr = &dummyFloat16Bias;
445 break;
446 }
447 case DataType::Float32:
448 {
449 biasInfoPtr = &dummyFloat32Bias;
450 break;
451 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000452 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000453 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000454 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000455 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100456 {
457 biasInfoPtr = &dummyQA8Bias;
458 break;
459 }
460 default:
461 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100462 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100463 }
464 }
465 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000466 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100467 OverrideDataType(input, dataType),
468 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000469 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100470 *biasInfoPtr,
471 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100472 reason);
telsoa014fcda012018-03-09 14:13:49 +0000473 break;
474 }
narpra01b89b05f2019-01-16 09:53:09 +0000475 case LayerType::Gather:
476 {
477 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
478 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
479 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100480 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
481 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000482 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
483 input1,
484 OverrideDataType(output, dataType),
485 descriptor,
486 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000487 break;
488 }
telsoa014fcda012018-03-09 14:13:49 +0000489 case LayerType::Input:
490 {
491 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000492 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000493 break;
494 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100495 case LayerType::InstanceNormalization:
496 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100497 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100498 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
499
500 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
501 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
502
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000503 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100504 OverrideDataType(input, dataType),
505 OverrideDataType(output, dataType),
506 descriptor,
507 reason);
508 break;
509 }
telsoa014fcda012018-03-09 14:13:49 +0000510 case LayerType::L2Normalization:
511 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100512 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100513 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
514
telsoa014fcda012018-03-09 14:13:49 +0000515 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100516 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100517
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000518 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100519 OverrideDataType(input, dataType),
520 OverrideDataType(output, dataType),
521 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100522 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100523 break;
524 }
James Conroyaba90cd2020-11-06 16:28:18 +0000525 case LayerType::LogicalBinary:
526 {
527 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
528
529 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
530 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
531 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
532
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000533 result = layerSupportObject.IsLogicalBinarySupported(input0,
534 input1,
535 output,
536 cLayer->GetParameters(),
537 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000538 break;
539 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100540 case LayerType::LogSoftmax:
541 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100542 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100543
544 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
545 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
546
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000547 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
548 OverrideDataType(output, dataType),
549 cLayer->GetParameters(),
550 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100551 break;
552 }
telsoa01c577f2c2018-08-31 09:22:23 +0100553 case LayerType::Lstm:
554 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100555 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100556 const LstmDescriptor& descriptor = cLayer->GetParameters();
557
558 // All inputs.
559 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
560 dataType);
561 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
562 dataType);
563 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
564 dataType);
565 // All outputs
566 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
567 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
568 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
569 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
570
571 // Basic parameters
572 const TensorInfo& inputToForgetWeights
573 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
574 const TensorInfo& inputToCellWeights
575 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
576 const TensorInfo& inputToOutputWeights
577 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
578 const TensorInfo& recurrentToForgetWeights
579 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
580 const TensorInfo& recurrentToCellWeights
581 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
582 const TensorInfo& recurrentToOutputWeights
583 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
584 const TensorInfo& forgetGateBias
585 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
586 const TensorInfo& cellBias
587 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
588 const TensorInfo& outputGateBias
589 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
590
Jan Eilersd01a83c2019-07-03 18:20:40 +0100591 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100592
Jan Eilersd01a83c2019-07-03 18:20:40 +0100593 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
594 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
595 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
596 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
597 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
598 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
599 paramsInfo.m_ForgetGateBias = &forgetGateBias;
600 paramsInfo.m_CellBias = &cellBias;
601 paramsInfo.m_OutputGateBias = &outputGateBias;
602
603
604 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100605 TensorInfo optInputToInputWeights;
606 TensorInfo optRecurrentToInputWeights;
607 TensorInfo optCellToInputWeights;
608 TensorInfo optInputGateBias;
609 TensorInfo optProjectionWeights;
610 TensorInfo optProjectionBias;
611 TensorInfo optCellToForgetWeights;
612 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100613 TensorInfo optInputLayerNormWeights;
614 TensorInfo optForgetLayerNormWeights;
615 TensorInfo optCellLayerNormWeights;
616 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100617
618 if(!descriptor.m_CifgEnabled)
619 {
620 optInputToInputWeights =
621 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100622 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100623
624 optRecurrentToInputWeights =
625 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100626 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100627 optInputGateBias =
628 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100629 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100630 }
631
632 if(descriptor.m_ProjectionEnabled)
633 {
634 optProjectionWeights =
635 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100636 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100637 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
638 {
639 optProjectionBias =
640 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100641 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100642 }
643 }
644
645 if(descriptor.m_PeepholeEnabled)
646 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100647 if(!descriptor.m_CifgEnabled)
648 {
649 optCellToInputWeights =
650 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
651 dataType);
652 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
653 }
telsoa01c577f2c2018-08-31 09:22:23 +0100654 optCellToForgetWeights =
655 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100656 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100657 optCellToOutputWeights =
658 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100659 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100660 }
661
Jan Eilers38e05bd2019-06-26 13:10:09 +0100662 if(descriptor.m_LayerNormEnabled)
663 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100664 if (!descriptor.m_CifgEnabled)
665 {
666 optInputLayerNormWeights = OverrideDataType(
667 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
668 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
669 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100670
671 optForgetLayerNormWeights = OverrideDataType(
672 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100673 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100674
675 optCellLayerNormWeights = OverrideDataType(
676 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100677 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100678
679 optOutputLayerNormWeights = OverrideDataType(
680 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100681 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100682 }
683
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000684 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100685 input,
686 outputStateIn,
687 cellStateIn,
688 scratchBuffer,
689 outputStateOut,
690 cellStateOut,
691 output,
692 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100693 paramsInfo,
694 reason);
telsoa014fcda012018-03-09 14:13:49 +0000695 break;
696 }
        // Element-wise maximum: both inputs and the output are validated against the
        // backend with the (possibly overridden) data type applied uniformly.
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        // Intra-backend copy: single input, single output.
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        // Memory import (zero-copy handover of an externally allocated buffer).
        case LayerType::MemImport:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             reason);
            break;
        }
        // Merge: two inputs forwarded to one output.
        case LayerType::Merge:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
                                                         OverrideDataType(input1, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        // Concat: variadic input count, so the TensorInfos are gathered through
        // transform iterators rather than fixed slot indices.
        case LayerType::Concat:
        {
            auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);

            // Get vector of all inputs (data type override applied per slot).
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };

            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            // The support query takes pointers; build them over the materialised
            // 'inputs' vector (NOT over the transform iterators, whose values are
            // temporaries).
            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };

            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        // Element-wise multiplication of two inputs.
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsMultiplicationSupported(
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason);
            break;
        }
        // Normalization: descriptor (NormalizationDescriptor) is forwarded from the layer.
        case LayerType::Normalization:
        {
            auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        // Output layer: the "output" tensor is what arrives on its single input slot.
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        // Permute: axis reordering, parameterised by PermuteDescriptor.
        case LayerType::Permute:
        {
            auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // Pad: parameterised by PadDescriptor (padding sizes / value).
        case LayerType::Pad:
        {
            auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPadSupported(
                                    OverrideDataType(input, dataType),
                                    OverrideDataType(output, dataType),
                                    cLayer->GetParameters(),
                                    reason);
            break;
        }
        // 2D pooling (max/avg/L2 per the descriptor).
        case LayerType::Pooling2d:
        {
            auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        // Pre-compiled blob: only the input is checked; output shape is dictated
        // by the pre-compiled object itself.
        case LayerType::PreCompiled:
        {
            auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                               cLayer->GetParameters(),
                                                               reason);
            break;
        }
        // Quantize: note that no OverrideDataType is applied here, unlike most
        // cases — presumably intentional, since quantization changes the data
        // type between input and output. TODO(review): confirm.
        case LayerType::Quantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsQuantizeSupported(input, output, reason);
            break;
        }
        // QLstm: quantised LSTM. Collects the three state/data inputs, the three
        // outputs, and wires every (conditionally present) weight/bias TensorInfo
        // into a LstmInputParamsInfo before querying the backend. Unlike the float
        // Lstm case above, the TensorInfos are referenced directly from the layer's
        // constant tensors — no OverrideDataType is applied.
        case LayerType::QLstm:
        {
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters
            LstmInputParamsInfo paramsInfo;

            // Basic parameters — mandatory, hence the asserts before dereferencing.
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // Input-gate parameters only exist when CIFG (coupled input-forget gate)
            // is disabled.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                        &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights also require CIFG disabled.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                            &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                        &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input layer-norm weights also require CIFG disabled.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                            &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +0100946 case LayerType::QuantizedLstm:
947 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100948 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100949
950 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100951 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
952 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
953 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100954
955 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100956 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
957 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100958
959 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100960 QuantizedLstmInputParamsInfo paramsInfo;
961
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100962 paramsInfo.m_InputToInputWeights =
963 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
964 paramsInfo.m_InputToForgetWeights =
965 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
966 paramsInfo.m_InputToCellWeights =
967 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
968 paramsInfo.m_InputToOutputWeights =
969 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100970
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100971 paramsInfo.m_RecurrentToInputWeights =
972 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
973 paramsInfo.m_RecurrentToForgetWeights =
974 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
975 paramsInfo.m_RecurrentToCellWeights =
976 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
977 paramsInfo.m_RecurrentToOutputWeights =
978 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100979
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100980 paramsInfo.m_InputGateBias =
981 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
982 paramsInfo.m_ForgetGateBias =
983 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
984 paramsInfo.m_CellBias =
985 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
986 paramsInfo.m_OutputGateBias =
987 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100988
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000989 result = layerSupportObject.IsQuantizedLstmSupported(input,
990 previousCellStateIn,
991 previousOutputIn,
992 cellStateOut,
993 output,
994 paramsInfo,
995 reason);
James Conroyee18dc82019-07-17 11:27:46 +0100996 break;
997 }
        // Element-wise division of two inputs.
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsDivisionSupported(
                                    OverrideDataType(input0, dataType),
                                    OverrideDataType(input1, dataType),
                                    OverrideDataType(output, dataType),
                                    reason);
            break;
        }
        // Rank: produces the rank of the input tensor.
        case LayerType::Rank:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
                                                        OverrideDataType(output, dataType),
                                                        reason);
            break;
        }
        // Reshape: parameterised by ReshapeDescriptor (target shape).
        case LayerType::Reshape:
        {
            auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // Resize: parameterised by ResizeDescriptor (method, target size, ...).
        case LayerType::Resize:
        {
            auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          cLayer->GetParameters(),
                                                          reason);
            break;
        }
        // Shape: produces the shape of the input tensor.
        case LayerType::Shape:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        // Slice: parameterised by SliceDescriptor (begin/size per dimension).
        case LayerType::Slice:
        {
            auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
                                                         OverrideDataType(output, dataType),
                                                         cLayer->GetParameters(),
                                                         reason);
            break;
        }
        // Softmax: parameterised by SoftmaxDescriptor (beta, axis).
        case LayerType::Softmax:
        {
            auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // SpaceToBatchNd: parameterised by SpaceToBatchNdDescriptor.
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        // SpaceToDepth: parameterised by SpaceToDepthDescriptor (block size, layout).
        case LayerType::SpaceToDepth:
        {
            auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
                                                                OverrideDataType(output, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        // Splitter: one input, variadic outputs; output TensorInfos are gathered
        // via transform iterators and handed over as reference_wrappers into the
        // locally materialised 'outputs' vector.
        case LayerType::Splitter:
        {
            auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();

            // Get vector of all outputs.
            auto getTensorInfo = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> outputs(beginI, endI);

            // reference_wrappers must point at 'outputs' (which outlives the call),
            // not at the transform-iterator temporaries.
            const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());

            result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
                                                            outputPtrs,
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        // Stack: variadic inputs, one output; same gather-then-take-pointers
        // pattern as the Concat case.
        case LayerType::Stack:
        {
            auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        // StandIn: placeholder layer with variadic inputs AND outputs; both sides
        // are gathered into TensorInfo vectors, then into pointer vectors over
        // those (stable) local vectors.
        case LayerType::StandIn:
        {
            auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfoIn = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
            auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
            std::vector<TensorInfo> outputs(beginO, endO);


            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);

            auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
            auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);


            result = layerSupportObject.IsStandInSupported(inputPtrs,
                                                           outputPtrs,
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // StridedSlice: parameterised by StridedSliceDescriptor (begin/end/stride masks).
        case LayerType::StridedSlice:
        {
            auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                OverrideDataType(output, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
David Beckc2044fe2018-09-05 15:00:38 +01001200 case LayerType::Subtraction:
1201 {
1202 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1203 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1204 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001205 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001206 OverrideDataType(input0, dataType),
1207 OverrideDataType(input1, dataType),
1208 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001209 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001210 break;
1211 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001212 case LayerType::Switch:
1213 {
1214 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1215 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1216 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1217 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001218 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1219 OverrideDataType(input1, dataType),
1220 OverrideDataType(output0, dataType),
1221 OverrideDataType(output1, dataType),
1222 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001223 break;
1224 }
narpra0132b90462018-09-13 11:07:48 +01001225 case LayerType::Mean:
1226 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001227 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001228 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1229 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001230 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001231 OverrideDataType(input, dataType),
1232 OverrideDataType(output, dataType),
1233 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001234 reason);
narpra0132b90462018-09-13 11:07:48 +01001235 break;
1236 }
kevmay0190539692018-11-29 08:40:19 +00001237 case LayerType::Minimum:
1238 {
1239 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1240 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1241 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001242 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1243 OverrideDataType(input1, dataType),
1244 OverrideDataType(output, dataType),
1245 reason);
kevmay0190539692018-11-29 08:40:19 +00001246 break;
1247 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001248 case LayerType::Prelu:
1249 {
1250 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1251 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1252 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001253 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1254 OverrideDataType(alpha, dataType),
1255 OverrideDataType(output, dataType),
1256 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001257 break;
1258 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001259 case LayerType::Transpose:
1260 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001261 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001262 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1263 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001264 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1265 OverrideDataType(output, dataType),
1266 cLayer->GetParameters(),
1267 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001268 break;
1269 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001270 case LayerType::TransposeConvolution2d:
1271 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001272 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001273
1274 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1275 dataType);
1276 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1277
1278 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1279
1280 Optional<TensorInfo> biases;
1281 if (descriptor.m_BiasEnabled)
1282 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001283 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001284 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1285 GetBiasTypeFromWeightsType(dataType));
1286 }
1287
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001288 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001289 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1290
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001291 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1292 output,
1293 descriptor,
1294 weights,
1295 biases,
1296 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001297
1298 break;
1299 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001300 case LayerType::Reduce:
1301 {
1302 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1303 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1304 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1305
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001306 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1307 OverrideDataType(output, dataType),
1308 cLayer->GetParameters(),
1309 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001310 break;
1311 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001312 case LayerType::UnidirectionalSequenceLstm:
1313 {
1314 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1315 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1316
1317 // All inputs.
1318 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1319 dataType);
1320 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1321 dataType);
1322 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1323 dataType);
1324 // Outputs
1325 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1326
1327 // Basic parameters
1328 const TensorInfo& inputToForgetWeights
1329 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1330 const TensorInfo& inputToCellWeights
1331 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1332 const TensorInfo& inputToOutputWeights
1333 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1334 const TensorInfo& recurrentToForgetWeights
1335 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1336 const TensorInfo& recurrentToCellWeights
1337 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1338 const TensorInfo& recurrentToOutputWeights
1339 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1340 const TensorInfo& forgetGateBias
1341 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1342 const TensorInfo& cellBias
1343 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1344 const TensorInfo& outputGateBias
1345 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1346
1347 LstmInputParamsInfo paramsInfo;
1348
1349 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1350 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1351 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1352 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1353 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1354 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1355 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1356 paramsInfo.m_CellBias = &cellBias;
1357 paramsInfo.m_OutputGateBias = &outputGateBias;
1358
1359 // Optional parameters
1360 TensorInfo optInputToInputWeights;
1361 TensorInfo optRecurrentToInputWeights;
1362 TensorInfo optCellToInputWeights;
1363 TensorInfo optInputGateBias;
1364 TensorInfo optProjectionWeights;
1365 TensorInfo optProjectionBias;
1366 TensorInfo optCellToForgetWeights;
1367 TensorInfo optCellToOutputWeights;
1368 TensorInfo optInputLayerNormWeights;
1369 TensorInfo optForgetLayerNormWeights;
1370 TensorInfo optCellLayerNormWeights;
1371 TensorInfo optOutputLayerNormWeights;
1372
1373 if(!descriptor.m_CifgEnabled)
1374 {
1375 optInputToInputWeights =
1376 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1377 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1378
1379 optRecurrentToInputWeights =
1380 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1381 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1382 optInputGateBias =
1383 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1384 paramsInfo.m_InputGateBias = &optInputGateBias;
1385 }
1386
1387 if(descriptor.m_ProjectionEnabled)
1388 {
1389 optProjectionWeights =
1390 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1391 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1392 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1393 {
1394 optProjectionBias =
1395 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1396 paramsInfo.m_ProjectionBias = &optProjectionBias;
1397 }
1398 }
1399
1400 if(descriptor.m_PeepholeEnabled)
1401 {
1402 if(!descriptor.m_CifgEnabled)
1403 {
1404 optCellToInputWeights =
1405 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1406 dataType);
1407 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1408 }
1409 optCellToForgetWeights =
1410 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1411 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1412 optCellToOutputWeights =
1413 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1414 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1415 }
1416
1417 if(descriptor.m_LayerNormEnabled)
1418 {
1419 if (!descriptor.m_CifgEnabled)
1420 {
1421 optInputLayerNormWeights = OverrideDataType(
1422 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1423 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1424 }
1425
1426 optForgetLayerNormWeights = OverrideDataType(
1427 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1428 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1429
1430 optCellLayerNormWeights = OverrideDataType(
1431 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1432 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1433
1434 optOutputLayerNormWeights = OverrideDataType(
1435 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1436 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1437 }
1438
1439 Optional<TensorInfo> hiddenStateOut;
1440 Optional<TensorInfo> cellStateOut;
1441
1442 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1443 outputStateIn,
1444 cellStateIn,
1445 output,
1446 hiddenStateOut,
1447 cellStateOut,
1448 descriptor,
1449 paramsInfo,
1450 reason);
1451 break;
1452 }
telsoa014fcda012018-03-09 14:13:49 +00001453 default:
1454 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001455 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001456 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001457 result = false;
1458 break;
1459 }
1460 }
telsoa014fcda012018-03-09 14:13:49 +00001461 return result;
1462}
1463
Sadik Armagan045f6be2020-09-10 13:37:32 +01001464bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1465 const IConnectableLayer& connectableLayer,
1466 Optional<DataType> dataType,
1467 std::string& outReasonIfUnsupported)
1468{
1469 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1470}
1471
David Beckdcb751f2018-10-03 11:42:42 +01001472bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001473 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001474 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001475{
Jan Eilersbb446e52020-04-02 13:56:54 +01001476 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001477 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1478}
1479
1480// TODO merge with defaulted modelOptions above
1481bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1482 Optional<DataType> dataType,
1483 std::string& outReasonIfUnsupported,
1484 const ModelOptions& modelOptions)
1485{
1486 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1487 return IsLayerConfigurationSupported(layer->GetBackendId(),
1488 connectableLayer,
1489 dataType,
1490 outReasonIfUnsupported,
1491 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001492}
1493
Sadik Armagan04a72972020-09-14 15:44:18 +01001494bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1495 const IConnectableLayer& connectableLayer,
1496 Optional<DataType> dataType,
1497 std::string& outReasonIfUnsupported,
1498 const ModelOptions& modelOptions)
1499{
1500 return IsLayerConfigurationSupported(backendId,
1501 connectableLayer,
1502 dataType,
1503 outReasonIfUnsupported,
1504 modelOptions);
1505}
1506
Derek Lamberti901ea112019-12-10 22:07:09 +00001507std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1508 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001509{
1510 return std::unique_ptr<IWorkload>();
1511}
1512
Derek Lamberti901ea112019-12-10 22:07:09 +00001513std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1514 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001515{
1516 return std::unique_ptr<IWorkload>();
1517}
1518
Derek Lamberti901ea112019-12-10 22:07:09 +00001519std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1520 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001521{
1522 return std::unique_ptr<IWorkload>();
1523}
1524
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001525std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001526 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001527{
1528 return std::unique_ptr<IWorkload>();
1529}
1530
Derek Lamberti901ea112019-12-10 22:07:09 +00001531std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1532 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001533{
1534 return std::unique_ptr<IWorkload>();
1535}
1536
mathad01b392e982021-04-07 12:07:30 +01001537std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1538 const WorkloadInfo& /*info*/) const
1539{
1540 return std::unique_ptr<IWorkload>();
1541}
1542
Simon Obute51f67772021-09-03 15:50:13 +01001543std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1544 const WorkloadInfo& /*info*/) const
1545{
1546 return std::unique_ptr<IWorkload>();
1547}
1548
Derek Lamberti901ea112019-12-10 22:07:09 +00001549std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1550 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001551{
1552 return std::unique_ptr<IWorkload>();
1553}
1554
Derek Lamberti901ea112019-12-10 22:07:09 +00001555std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1556 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001557{
1558 return std::unique_ptr<IWorkload>();
1559}
1560
Derek Lamberti901ea112019-12-10 22:07:09 +00001561std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1562 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001563{
1564 return std::unique_ptr<IWorkload>();
1565}
1566
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001567std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1568 const WorkloadInfo& /*info*/) const
1569{
1570 return std::unique_ptr<IWorkload>();
1571}
1572
Derek Lamberti901ea112019-12-10 22:07:09 +00001573std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1574 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001575{
1576 return std::unique_ptr<IWorkload>();
1577}
1578
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001579std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1580 const WorkloadInfo& /*info*/) const
1581{
1582 return std::unique_ptr<IWorkload>();
1583}
1584
Derek Lamberti901ea112019-12-10 22:07:09 +00001585std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1586 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001587{
1588 return std::unique_ptr<IWorkload>();
1589}
1590
Derek Lamberti901ea112019-12-10 22:07:09 +00001591std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1592 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001593{
1594 return std::unique_ptr<IWorkload>();
1595}
1596
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001597std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/,
1598 const WorkloadInfo& /*info*/) const
1599{
1600 return std::unique_ptr<IWorkload>();
1601}
1602
Derek Lamberti901ea112019-12-10 22:07:09 +00001603std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1604 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001605{
1606 return std::unique_ptr<IWorkload>();
1607}
1608
Derek Lamberti901ea112019-12-10 22:07:09 +00001609std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1610 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001611{
1612 return std::unique_ptr<IWorkload>();
1613}
1614
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001615std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001616 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001617{
1618 return std::unique_ptr<IWorkload>();
1619}
1620
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001621std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001622 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001623{
1624 return std::unique_ptr<IWorkload>();
1625}
1626
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001627std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001628 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001629{
1630 return std::unique_ptr<IWorkload>();
1631}
1632
Derek Lamberti901ea112019-12-10 22:07:09 +00001633std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1634 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001635{
1636 return std::unique_ptr<IWorkload>();
1637}
1638
josh minor4a3c6102020-01-06 16:40:46 -06001639std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1640 const WorkloadInfo& /*info*/) const
1641{
1642 return std::unique_ptr<IWorkload>();
1643}
1644
Derek Lamberti901ea112019-12-10 22:07:09 +00001645std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1646 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001647{
1648 return std::unique_ptr<IWorkload>();
1649}
1650
Ryan OSheaec6c6802020-06-05 17:17:06 +01001651std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1652 const WorkloadInfo& /*info*/) const
1653{
1654 return std::unique_ptr<IWorkload>();
1655}
1656
Derek Lamberti901ea112019-12-10 22:07:09 +00001657std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1658 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001659{
1660 return std::unique_ptr<IWorkload>();
1661}
1662
Derek Lamberti901ea112019-12-10 22:07:09 +00001663std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1664 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001665{
1666 return std::unique_ptr<IWorkload>();
1667}
1668
Derek Lamberti901ea112019-12-10 22:07:09 +00001669std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1670 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001671{
1672 return std::unique_ptr<IWorkload>();
1673}
1674
Kevin Mayce5045a2019-10-02 14:07:47 +01001675std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001676 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1677 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001678{
1679 return std::unique_ptr<IWorkload>();
1680}
1681
Derek Lamberti901ea112019-12-10 22:07:09 +00001682std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1683 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001684{
1685 return std::unique_ptr<IWorkload>();
1686}
1687
James Conroyaba90cd2020-11-06 16:28:18 +00001688std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1689 const WorkloadInfo& /*info*/) const
1690{
1691 return std::unique_ptr<IWorkload>();
1692}
1693
1694std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1695 const WorkloadInfo& /*info*/) const
1696{
1697 return std::unique_ptr<IWorkload>();
1698}
1699
Derek Lamberti901ea112019-12-10 22:07:09 +00001700std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1701 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001702{
1703 return std::unique_ptr<IWorkload>();
1704}
1705
Derek Lamberti901ea112019-12-10 22:07:09 +00001706std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1707 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001708{
1709 return std::unique_ptr<IWorkload>();
1710}
1711
Derek Lamberti901ea112019-12-10 22:07:09 +00001712std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1713 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001714{
1715 return std::unique_ptr<IWorkload>();
1716}
1717
Derek Lamberti901ea112019-12-10 22:07:09 +00001718std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1719 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001720{
1721 return std::unique_ptr<IWorkload>();
1722}
1723
Derek Lamberti901ea112019-12-10 22:07:09 +00001724std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1725 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001726{
1727 return std::unique_ptr<IWorkload>();
1728}
1729
Derek Lamberti901ea112019-12-10 22:07:09 +00001730std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1731 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001732{
1733 return std::unique_ptr<IWorkload>();
1734}
1735
Derek Lamberti901ea112019-12-10 22:07:09 +00001736std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1737 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001738{
1739 return std::unique_ptr<IWorkload>();
1740}
1741
Derek Lamberti901ea112019-12-10 22:07:09 +00001742std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1743 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001744{
1745 return std::unique_ptr<IWorkload>();
1746}
1747
Derek Lamberti901ea112019-12-10 22:07:09 +00001748std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1749 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001750{
1751 return std::unique_ptr<IWorkload>();
1752}
1753
Derek Lamberti901ea112019-12-10 22:07:09 +00001754std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1755 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001756{
1757 return std::unique_ptr<IWorkload>();
1758}
1759
Derek Lamberti901ea112019-12-10 22:07:09 +00001760std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1761 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001762{
1763 return std::unique_ptr<IWorkload>();
1764}
1765
Derek Lamberti901ea112019-12-10 22:07:09 +00001766std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1767 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001768{
1769 return std::unique_ptr<IWorkload>();
1770}
1771
Derek Lamberti901ea112019-12-10 22:07:09 +00001772std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001773 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001774{
1775 return std::unique_ptr<IWorkload>();
1776}
1777
Derek Lamberti901ea112019-12-10 22:07:09 +00001778std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1779 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001780{
1781 return std::unique_ptr<IWorkload>();
1782}
1783
Derek Lamberti901ea112019-12-10 22:07:09 +00001784std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1785 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001786{
1787 return std::unique_ptr<IWorkload>();
1788}
1789
Derek Lamberti901ea112019-12-10 22:07:09 +00001790std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1791 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001792{
1793 return std::unique_ptr<IWorkload>();
1794}
1795
Derek Lamberti901ea112019-12-10 22:07:09 +00001796std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1797 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001798{
1799 return std::unique_ptr<IWorkload>();
1800}
1801
James Conroy586a9aa2020-03-20 08:49:33 +00001802std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1803 const WorkloadInfo& /*info*/) const
1804{
1805 return std::unique_ptr<IWorkload>();
1806}
1807
Derek Lamberti901ea112019-12-10 22:07:09 +00001808std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1809 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001810{
1811 return std::unique_ptr<IWorkload>();
1812}
Finn Williams2605b232020-06-10 15:53:46 +01001813std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1814 const WorkloadInfo& /*info*/) const
1815{
1816 return std::unique_ptr<IWorkload>();
1817}
James Conroyee18dc82019-07-17 11:27:46 +01001818
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001819std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1820 const WorkloadInfo& /*info*/) const
1821{
1822 return std::unique_ptr<IWorkload>();
1823}
1824
Derek Lamberti901ea112019-12-10 22:07:09 +00001825std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1826 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001827{
1828 return std::unique_ptr<IWorkload>();
1829}
1830
Derek Lamberti901ea112019-12-10 22:07:09 +00001831std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1832 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001833{
1834 return std::unique_ptr<IWorkload>();
1835}
1836
Keith Davis3ae3f972021-05-21 16:33:48 +01001837std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
1838 const WorkloadInfo& /*info*/) const
1839{
1840 return std::unique_ptr<IWorkload>();
1841}
1842
Derek Lamberti901ea112019-12-10 22:07:09 +00001843std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1844 const WorkloadInfo& /*info*/) const
1845{
1846 return std::unique_ptr<IWorkload>();
1847}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001848
Derek Lamberti901ea112019-12-10 22:07:09 +00001849std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1850 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001851{
1852 return std::unique_ptr<IWorkload>();
1853}
1854
Derek Lamberti901ea112019-12-10 22:07:09 +00001855std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1856 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001857{
1858 return std::unique_ptr<IWorkload>();
1859}
1860
Derek Lamberti901ea112019-12-10 22:07:09 +00001861std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1862 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001863{
1864 return std::unique_ptr<IWorkload>();
1865}
1866
Derek Lamberti901ea112019-12-10 22:07:09 +00001867std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1868 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001869{
1870 return std::unique_ptr<IWorkload>();
1871}
1872
Derek Lamberti901ea112019-12-10 22:07:09 +00001873std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1874 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001875{
1876 return std::unique_ptr<IWorkload>();
1877}
1878
Derek Lamberti901ea112019-12-10 22:07:09 +00001879std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1880 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001881{
1882 return std::unique_ptr<IWorkload>();
1883}
1884
Derek Lamberti901ea112019-12-10 22:07:09 +00001885std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1886 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001887{
1888 return std::unique_ptr<IWorkload>();
1889}
1890
Derek Lamberti901ea112019-12-10 22:07:09 +00001891std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1892 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001893{
1894 return std::unique_ptr<IWorkload>();
1895}
1896
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001897std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1898 const WorkloadInfo& /*info*/) const
1899{
1900 return std::unique_ptr<IWorkload>();
1901}
1902
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001903std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001904 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1905 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001906{
1907 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001908}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001909
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001910std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
1911 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
1912 const WorkloadInfo& /*info*/) const
1913{
1914 return std::unique_ptr<IWorkload>();
1915}
1916
} // namespace armnn