blob: 25334c3b5247db220e5e3f408479f94893c4f0f7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matteo Martincighe011d202019-11-28 11:35:47 +000012#include <LayerSupportCommon.hpp>
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000015
Matteo Martincighe011d202019-11-28 11:35:47 +000016#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000020#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
30template<typename Float32Func, typename Uint8Func, typename ... Params>
31bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
32 DataType dataType,
33 Float32Func floatFuncPtr,
34 Uint8Func uint8FuncPtr,
35 Params&&... params)
36{
37 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
38 dataType,
39 &FalseFunc<Params...>,
40 floatFuncPtr,
41 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000042 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000043 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 std::forward<Params>(params)...);
45}
46
47} // anonymous namespace
48
namespace
{

/// Builds the diagnostic used when a tensor does not have the rank a reference layer expects.
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Layer name inserted after "Reference ".
/// @param tensorName Name of the offending tensor (e.g. "input", "output").
/// @return "Reference <layer>: Expected <N> dimensions but got <M> dimensions instead, for the '<tensor>' tensor."
/// The string arguments are only read, so they are taken by const reference; this also lets callers
/// pass temporaries and string literals (the previous non-const references rejected those).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000064
Sadik Armagan9199e582019-09-05 17:35:31 +010065bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
66 Optional<std::string&> reasonIfUnsupported) const
67{
josh minor4a3c6102020-01-06 16:40:46 -060068 return IsElementwiseUnarySupported(input,
69 output,
70 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
71 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010072}
73
arovir011c7c81b2018-10-08 11:34:28 +010074bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
75 const TensorInfo& output,
76 const ActivationDescriptor& descriptor,
77 Optional<std::string&> reasonIfUnsupported) const
78{
Derek Lamberti50db4e82019-03-13 14:16:15 +000079 bool supported = true;
80
81 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000082 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +000083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010084 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000085 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000086 DataType::QAsymmU8,
87 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000088 };
89
90 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
91 "Reference activation: input type not supported.");
92
93 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
94 "Reference activation: output type not supported.");
95
96 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
97 "Reference activation: input and output types mismatched.");
98
99 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
100 "Reference activation: input and output shapes are of different rank.");
101
102
103 struct ActivationFunctionSupported : public Rule
104 {
105 ActivationFunctionSupported(const ActivationDescriptor& desc)
106 {
107 switch(desc.m_Function)
108 {
109 case ActivationFunction::Abs:
110 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000111 case ActivationFunction::Elu:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000112 case ActivationFunction::LeakyReLu:
113 case ActivationFunction::Linear:
114 case ActivationFunction::ReLu:
115 case ActivationFunction::Sigmoid:
116 case ActivationFunction::SoftReLu:
117 case ActivationFunction::Sqrt:
118 case ActivationFunction::Square:
119 case ActivationFunction::TanH:
120 {
121 m_Res = true;
122 break;
123 }
124 default:
125 {
126 m_Res = false;
127 break;
128 }
129 }
130 }
131 };
132
133 // Function is supported
134 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
135 "Reference activation: function not supported.");
136
137 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100138}
139
140bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
141 const TensorInfo& input1,
142 const TensorInfo& output,
143 Optional<std::string&> reasonIfUnsupported) const
144{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000145 bool supported = true;
146
Keith Davis0c2eeac2020-02-11 16:51:50 +0000147 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000148 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100149 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000150 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000151 DataType::QAsymmU8,
152 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000153 };
154
155 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
156 "Reference addition: input 0 is not a supported type.");
157
158 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
159 "Reference addition: input 1 is not a supported type.");
160
161 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
162 "Reference addition: output is not a supported type.");
163
164 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
165 "Reference addition: input 0 and Input 1 types are mismatched");
166
167 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
168 "Reference addition: input and output types are mismatched");
169
170 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
171 "Reference addition: shapes are not suitable for implicit broadcast.");
172
173 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100174}
175
Nikhil Raj68c2c902019-09-19 11:21:11 +0100176bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
177 const armnn::ArgMinMaxDescriptor &descriptor,
178 armnn::Optional<std::string &> reasonIfUnsupported) const
179{
180 ignore_unused(descriptor);
181
Francis Murtagh1939df52019-11-13 15:21:09 +0000182 std::array<DataType, 4> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100183 {
184 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000185 DataType::QAsymmU8,
186 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000187 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100188 };
189
190 bool supported = true;
191
192 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
193 "Reference ArgMinMax: input is not a supported type.");
194 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
195 "Reference ArgMinMax: output type not supported");
196
197 return supported;
198}
199
arovir011c7c81b2018-10-08 11:34:28 +0100200bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
201 const TensorInfo& output,
202 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100203 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100204 const TensorInfo& beta,
205 const TensorInfo& gamma,
206 const BatchNormalizationDescriptor& descriptor,
207 Optional<std::string&> reasonIfUnsupported) const
208{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100209 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100210
Matthew Jackson9bff1442019-09-12 09:08:23 +0100211 std::array<DataType, 4> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100212 {
213 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100214 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000215 DataType::QAsymmU8,
216 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100217 };
218
219 bool supported = true;
220
221 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
222 "Reference batch normalization: input is not a supported type.");
223
224 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
225 "Reference batch normalization: output is not a supported type.");
226
227 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
228 "Reference batch normalization: input and output types are mismatched");
229
230 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
231 "Reference batch normalization: mean is not a supported type.");
232
233 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
234 "Reference batch normalization: variance is not a supported type.");
235
236 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
237 "Reference batch normalization: beta is not a supported type.");
238
239 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
240 "Reference batch normalization: gamma is not a supported type.");
241
242 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100243}
244
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000245bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
246 const TensorInfo& output,
247 const BatchToSpaceNdDescriptor& descriptor,
248 Optional<std::string&> reasonIfUnsupported) const
249{
250 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100251
252 bool supported = true;
253
254 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
255 std::string inputTensorStr = "input";
256 std::string outputTensorStr = "output";
257
258 // Define supported types.
Matthew Jackson9bff1442019-09-12 09:08:23 +0100259 std::array<DataType,4> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100260 {
261 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100262 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000263 DataType::QAsymmU8,
264 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100265 };
266
267 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
268 "Reference BatchToSpaceNd: input type not supported.");
269
270 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
271 "Reference BatchToSpaceNd: output type not supported.");
272
273 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
274 "Reference BatchToSpaceNd: input and output types mismatched.");
275
276 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
277 reasonIfUnsupported,
278 CreateIncorrectDimensionsErrorMsg(4,
279 output.GetNumDimensions(),
280 batchToSpaceNdLayerStr,
281 outputTensorStr).data());
282
283 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
284 reasonIfUnsupported,
285 CreateIncorrectDimensionsErrorMsg(4,
286 input.GetNumDimensions(),
287 batchToSpaceNdLayerStr,
288 inputTensorStr).data());
289
290 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000291}
292
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100293bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
294 const TensorInfo& input1,
295 const TensorInfo& output,
296 const ComparisonDescriptor& descriptor,
297 Optional<std::string&> reasonIfUnsupported) const
298{
299 boost::ignore_unused(descriptor);
300
301 std::array<DataType, 4> supportedInputTypes =
302 {
303 DataType::Float32,
304 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000305 DataType::QAsymmU8,
306 DataType::QSymmS16
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100307 };
308
309 bool supported = true;
310 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
311 "Reference comparison: input 0 is not a supported type");
312
313 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
314 "Reference comparison: input 0 and Input 1 types are mismatched");
315
316 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
317 "Reference comparison: output is not of type Boolean");
318
319 return supported;
320}
321
Jim Flynn906f9462019-05-10 13:55:21 +0100322bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
323 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100324 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100325 Optional<std::string&> reasonIfUnsupported) const
326{
Jim Flynne242f2d2019-05-22 14:24:13 +0100327 ignore_unused(descriptor);
328
329 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +0000330 std::array<DataType,5> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100331 {
332 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100333 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000334 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +0000335 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000336 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100337 };
338
339 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
340 "Reference concatenation: output type not supported");
341 for (const TensorInfo* input : inputs)
342 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100343 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100344 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
345 "Reference concatenation: input type not supported");
346
347 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
348 "Reference concatenation: input and output types mismatched.");
349 }
350
351 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100352}
353
arovir011c7c81b2018-10-08 11:34:28 +0100354bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
355 Optional<std::string&> reasonIfUnsupported) const
356{
Keith Davis67e6c542020-02-19 10:08:33 +0000357 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100358 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100359 DataType::Float32,
360 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000361 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +0000362 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000363 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000364 DataType::QSymmS16
Nina Drozd58ef2c62019-05-16 12:09:18 +0100365 };
366
367 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
368 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100369}
370
371bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
372 const TensorInfo& output,
373 Optional<std::string&> reasonIfUnsupported) const
374{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100375 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
376 input.GetDataType(),
377 &TrueFunc<>,
378 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000379 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000380 &FalseFuncI32<>,
381 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100382 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
383 output.GetDataType(),
384 &FalseOutputFuncF16<>,
385 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000386 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000387 &FalseFuncI32<>,
388 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100389}
390
391bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
392 const TensorInfo& output,
393 Optional<std::string&> reasonIfUnsupported) const
394{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100395 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
396 input.GetDataType(),
397 &FalseInputFuncF16<>,
398 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000399 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000400 &FalseFuncI32<>,
401 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100402 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
403 output.GetDataType(),
404 &TrueFunc<>,
405 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000406 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000407 &FalseFuncI32<>,
408 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100409}
410
411bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
412 const TensorInfo& output,
413 const Convolution2dDescriptor& descriptor,
414 const TensorInfo& weights,
415 const Optional<TensorInfo>& biases,
416 Optional<std::string&> reasonIfUnsupported) const
417{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100418 bool supported = true;
419
420 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000421 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000422 {
423 DataType::Float32,
424 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000425 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000426 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000427 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000428 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100429 };
430
431 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000432 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100433
434 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000435 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100436
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100437 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000438 "Reference Convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100439
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000440 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000441 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000442 {
Derek Lambertid466a542020-01-22 15:37:29 +0000443 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000444 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000445 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000446 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000447 DataType::QSymmS8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000448 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000449 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000450 };
Derek Lambertid466a542020-01-22 15:37:29 +0000451 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000452
453 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000454 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000455 }
456 else
457 {
458 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000459 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000460
461 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000462 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000463 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100464
465 if (biases.has_value())
466 {
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000467 std::array<DataType,3> biasesSupportedTypes =
468 {
469 DataType::Float32,
470 DataType::Float16,
471 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100472 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000473
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100474 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000475 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100476 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100477 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100478
479 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100480}
481
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000482bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
483 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000484 Optional<std::string&> reasonIfUnsupported) const
485{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100486 bool supported = true;
487
Keith Davis0c2eeac2020-02-11 16:51:50 +0000488 std::array<DataType, 7> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100489 {
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000490 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100491 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000492 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000493 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000494 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000495 DataType::QSymmS16,
496 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100497 };
498
499 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000500 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100501
502 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000503 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100504
505 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000506 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100507
508 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000509}
510
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100511bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
512 const TensorInfo& output,
513 const DepthToSpaceDescriptor& descriptor,
514 Optional<std::string&> reasonIfUnsupported) const
515{
516 ignore_unused(descriptor);
517 bool supported = true;
518
519 std::array<DataType,4> supportedTypes =
520 {
521 DataType::Float32,
522 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000523 DataType::QAsymmU8,
524 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100525 };
526
527 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
528 "Reference DepthToSpace: input type not supported");
529
530 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
531 "Reference DepthToSpace: output type not supported");
532
533 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
534 "Reference DepthToSpace: input and output types are mismatched");
535
536 return supported;
537}
538
arovir011c7c81b2018-10-08 11:34:28 +0100539bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
540 const TensorInfo& output,
541 const DepthwiseConvolution2dDescriptor& descriptor,
542 const TensorInfo& weights,
543 const Optional<TensorInfo>& biases,
544 Optional<std::string&> reasonIfUnsupported) const
545{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100546 bool supported = true;
547
548 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000549 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100550 {
551 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100552 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000553 DataType::QSymmS8,
554 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000555 DataType::QAsymmU8,
556 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100557 };
558
559 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
560 "Reference DepthwiseConvolution2d: input is not a supported type.");
561
562 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
563 "Reference DepthwiseConvolution2d: output is not a supported type.");
564
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
566 "Reference DepthwiseConvolution2d: input and output types mismatched.");
567
Derek Lambertid466a542020-01-22 15:37:29 +0000568 ARMNN_NO_DEPRECATE_WARN_BEGIN
569 std::array<DataType, 3> supportedWeightTypes =
570 {
571 DataType::QAsymmU8,
572 DataType::QSymmS8,
573 DataType::QuantizedSymm8PerAxis // deprecated
574 };
575 ARMNN_NO_DEPRECATE_WARN_END
576
Teresa Charlind8df0262019-11-11 12:28:15 +0000577 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000578 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000579 {
Teresa Charlind8df0262019-11-11 12:28:15 +0000580
581 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
582 "Reference convolution2d: weights type not supported for quantized input.");
583 }
584 else
585 {
586 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
587 "Reference DepthwiseConvolution2d: weights is not a supported type.");
588
589 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
590 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
591 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100592
593 if (biases.has_value())
594 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100595 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100596 {
597 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100598 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100599 DataType::Signed32
600 };
601 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
602 "Reference DepthwiseConvolution2d: biases is not a supported type.");
603 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100604 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100605
606 return supported;
607
arovir011c7c81b2018-10-08 11:34:28 +0100608}
609
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000610bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
611 const TensorInfo& output,
612 Optional<std::string&> reasonIfUnsupported) const
613{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100614 bool supported = true;
615
Ryan OShea9add1202020-02-07 10:06:33 +0000616 std::array<DataType,4> supportedInputTypes = {
617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000618 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000619 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000620 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100621 };
622
623 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000624 "Reference for Dequantize layer: input type not supported.");
625
626 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
627 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100628
Derek Lambertid466a542020-01-22 15:37:29 +0000629 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
630 "Reference dequantize: per-axis quantized input not support .");
631
Jan Eilersf7107932019-11-01 11:09:36 +0000632 std::array<DataType,2> supportedOutputTypes = {
633 DataType::Float32,
634 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100635 };
636
637 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000638 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100639
640 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000641 "Reference for Dequantize layer: input/output shapes have different num total "
642 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100643
644 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000645}
646
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000647bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
648 const TensorInfo& scores,
649 const TensorInfo& anchors,
650 const TensorInfo& detectionBoxes,
651 const TensorInfo& detectionClasses,
652 const TensorInfo& detectionScores,
653 const TensorInfo& numDetections,
654 const DetectionPostProcessDescriptor& descriptor,
655 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000656{
Derek Lamberti901ea112019-12-10 22:07:09 +0000657 boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
658
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100659 bool supported = true;
660
Mike Kelly4992c342019-08-14 11:33:11 +0100661 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100662 {
663 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000664 DataType::QAsymmU8,
665 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100666 };
667
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000668 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100669 "Reference DetectionPostProcess: input 0 is not a supported type.");
670
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000671 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100672 "Reference DetectionPostProcess: input 1 is not a supported type.");
673
674 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000675}
676
Pablo Tellof0bd6832019-04-26 17:58:13 +0100677bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
678 const TensorInfo& output,
679 const DepthwiseConvolution2dDescriptor& descriptor,
680 const TensorInfo& weights,
681 const Optional<TensorInfo>& biases,
682 Optional<std::string&> reasonIfUnsupported) const
683{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100684 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100685}
686
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100687bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100688 const TensorInfo& input1,
689 const TensorInfo& output,
690 Optional<std::string&> reasonIfUnsupported) const
691{
Sadik Armagan2999a022019-04-09 14:20:12 +0100692 bool supported = true;
693
Matthew Jackson9bff1442019-09-12 09:08:23 +0100694 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100695 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100696 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000697 DataType::QAsymmU8,
698 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100699 };
700
701 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
702 "Reference division: input 0 is not a supported type.");
703
704 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
705 "Reference division: input 1 is not a supported type.");
706
707 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
708 "Reference division: output is not a supported type.");
709
710 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
711 "Reference division: input 0 and Input 1 types are mismatched");
712
713 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
714 "Reference division: input and output types are mismatched");
715
716 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
717 "Reference division: shapes are not suitable for implicit broadcast.");
718
719 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100720}
721
josh minor4a3c6102020-01-06 16:40:46 -0600722bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
723 const TensorInfo& output,
724 const ElementwiseUnaryDescriptor& descriptor,
725 Optional<std::string&> reasonIfUnsupported) const
726{
727 boost::ignore_unused(descriptor);
728
729 std::array<DataType, 4> supportedTypes =
730 {
731 DataType::Float32,
732 DataType::Float16,
733 DataType::QAsymmU8,
734 DataType::QSymmS16
735 };
736
737 bool supported = true;
738
739 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
740 "Reference elementwise unary: input type not supported");
741
742 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
743 "Reference elementwise unary: output type not supported");
744
745 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
746 "Reference elementwise unary: input and output types not matching");
747
748 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
749 "Reference elementwise unary: input and output shapes"
750 "have different number of total elements");
751
752 return supported;
753}
754
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000755bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
756 const TensorInfo& input1,
757 const TensorInfo& output,
758 Optional<std::string&> reasonIfUnsupported) const
759{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100760 return IsComparisonSupported(input0,
761 input1,
762 output,
763 ComparisonDescriptor(ComparisonOperation::Equal),
764 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000765}
766
arovir011c7c81b2018-10-08 11:34:28 +0100767bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
768 const FakeQuantizationDescriptor& descriptor,
769 Optional<std::string&> reasonIfUnsupported) const
770{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100771 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100772 bool supported = true;
773
774 std::array<DataType,1> supportedTypes =
775 {
776 DataType::Float32
777 };
778
779 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
780 "Reference fake quantization: input type not supported.");
781
782 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100783}
784
785bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
786 const TensorInfo& output,
787 Optional<std::string&> reasonIfUnsupported) const
788{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100789 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100790 bool supported = true;
791
Matthew Jackson9bff1442019-09-12 09:08:23 +0100792 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100793 {
James Conroyb40d7102019-06-04 12:32:09 +0100794 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100795 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000796 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100797 };
798
799 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
800 "Reference Floor: input type not supported.");
801
802 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
803 "Reference Floor: output type not supported.");
804
805 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100806}
807
808bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
809 const TensorInfo& output,
810 const TensorInfo& weights,
811 const TensorInfo& biases,
812 const FullyConnectedDescriptor& descriptor,
813 Optional<std::string&> reasonIfUnsupported) const
814{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100815 bool supported = true;
816
817 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100818 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100819 {
820 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100821 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000822 DataType::QAsymmU8,
823 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100824 };
825
826 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
827 "Reference Fully Connected: input type not supported.");
828
829 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
830 "Reference Fully Connected: output type not supported.");
831
832 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
833 "Reference Fully Connected: input and output types mismatched.");
834
835 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
836 "Reference Fully Connected: weights type not supported.");
837
838 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
839 "Reference Fully Connected: input and weight types mismatched.");
840
841 if (descriptor.m_BiasEnabled)
842 {
843 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100844 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100845 supportedBiasTypes =
846 {
847 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100848 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100849 DataType::Signed32
850 };
851
852 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
853 "Reference Fully Connected: bias type not supported.");
854
855 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
856 "Reference Fully Connected: bias and weight types mismatch.");
857
858 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
859 "Reference Fully Connected: bias type inferred from weights is incompatible.");
860
861 }
862
863 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100864}
865
narpra014951d842019-01-18 16:53:53 +0000866bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
867 const armnn::TensorInfo& input1,
868 const armnn::TensorInfo& output,
869 armnn::Optional<std::string&> reasonIfUnsupported) const
870{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100871 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100872 std::array<DataType,4> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100873 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100874 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100875 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000876 DataType::QAsymmU8,
877 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100878 };
879
880 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
881 "Reference Gather: input type not supported");
882
883 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
884 "Reference Gather: output type not supported");
885
886 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
887 "Reference Gather: indices (input1) type not supported");
888
889 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
890 "Reference Gather: input and output types not matching");
891
892 return supported;
narpra014951d842019-01-18 16:53:53 +0000893}
894
FrancisMurtagh878f0232018-12-19 10:56:15 +0000895bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
896 const TensorInfo& input1,
897 const TensorInfo& output,
898 Optional<std::string&> reasonIfUnsupported) const
899{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100900 return IsComparisonSupported(input0,
901 input1,
902 output,
903 ComparisonDescriptor(ComparisonOperation::Greater),
904 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +0000905}
906
Derek Lamberti901ea112019-12-10 22:07:09 +0000907bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
908 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +0100909{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100910 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100911}
912
Kevin May09ca49c2019-10-09 12:37:34 +0100913bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
914 const TensorInfo& output,
915 const InstanceNormalizationDescriptor& descriptor,
916 Optional<std::string&> reasonIfUnsupported) const
917{
918 ignore_unused(descriptor);
919 // Define supported types
920 std::array<DataType, 4> supportedTypes =
921 {
922 DataType::Float32,
923 DataType::Float16
924 };
925
926 bool supported = true;
927
928 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
929 "Reference Instance Normalization: input type not supported.");
930
931 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
932 "Reference Instance Normalization: output type not supported.");
933
934 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
935 "Reference Instance Normalization: input and output types mismatched.");
936
937 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
938 "Reference Instance Normalization: input and output shapes have different "
939 "num total elements.");
940
941 return supported;
942}
943
arovir011c7c81b2018-10-08 11:34:28 +0100944bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
945 const TensorInfo& output,
946 const L2NormalizationDescriptor& descriptor,
947 Optional<std::string&> reasonIfUnsupported) const
948{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100949 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100950 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100951 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100952 {
953 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100954 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000955 DataType::QAsymmU8,
956 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100957 };
958
959 bool supported = true;
960
961 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
962 "Reference L2normalization: input type not supported.");
963
964 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
965 "Reference L2normalization: output type not supported.");
966
967 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
968 "Reference L2normalization: input and output types mismatched.");
969
970 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
971 "Reference L2normalization: input and output shapes have different "
972 "num total elements.");
973
974 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100975}
976
Aron Virginas-Tare662a942019-10-14 15:12:00 +0100977bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
978 const TensorInfo& output,
979 const LogSoftmaxDescriptor& descriptor,
980 Optional<std::string&> reasonIfUnsupported) const
981{
982 ignore_unused(descriptor);
983
984 std::array<DataType, 2> supportedTypes =
985 {
986 DataType::Float32,
987 DataType::Float16
988 };
989
990 bool supported = true;
991 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
992 "Reference LogSoftmax: input type not supported");
993
994 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
995 "Reference LogSoftmax: output type not supported");
996
997 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
998 "Reference LogSoftmax: input and output types do not match");
999
1000 return supported;
1001}
1002
arovir011c7c81b2018-10-08 11:34:28 +01001003bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1004 const TensorInfo& outputStateIn,
1005 const TensorInfo& cellStateIn,
1006 const TensorInfo& scratchBuffer,
1007 const TensorInfo& outputStateOut,
1008 const TensorInfo& cellStateOut,
1009 const TensorInfo& output,
1010 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001011 const LstmInputParamsInfo& paramsInfo,
1012 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001013{
telsoa01c577f2c2018-08-31 09:22:23 +01001014 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +01001015 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001016
1017 bool supported = true;
1018
1019 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001020 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001021 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001022 };
1023
Jan Eilersd01a83c2019-07-03 18:20:40 +01001024 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001025 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1026 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001027 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1028 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001029 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1030 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001031 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1032 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001033 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1034 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001035 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1036 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001037 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1038 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001039 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001040 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001041 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001042 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001043 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001044 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001045 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001046 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001047 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001048 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001049 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001050 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001051 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001052 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001053 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001054 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001055 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001056 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001057 "Reference Lstm: input and OutputGateBias types are mismatched");
1058 if (!descriptor.m_CifgEnabled)
1059 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001060 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001061 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001062 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001063 reasonIfUnsupported,
1064 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001065 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001066 "Reference Lstm: input and InputGateBias types are mismatched");
1067 if (descriptor.m_PeepholeEnabled)
1068 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001069 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001070 reasonIfUnsupported,
1071 "Reference Lstm: input and CellToInputWeights types are mismatched");
1072 }
1073 }
1074 if (descriptor.m_PeepholeEnabled)
1075 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001076 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001077 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001078 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001079 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1080 }
1081 if (descriptor.m_ProjectionEnabled)
1082 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001083 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001084 "Reference Lstm: input and mProjectionWeights types are mismatched");
1085 if (paramsInfo.m_ProjectionBias != nullptr)
1086 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001087 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001088 "Reference Lstm: input and ProjectionBias types are mismatched");
1089 }
1090 }
1091 if (descriptor.m_LayerNormEnabled)
1092 {
1093 if (!descriptor.m_CifgEnabled)
1094 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001095 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001096 reasonIfUnsupported,
1097 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1098 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001099 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001100 reasonIfUnsupported,
1101 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001102 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001103 reasonIfUnsupported,
1104 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001105 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001106 reasonIfUnsupported,
1107 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1108 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001109
1110 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001111}
1112
saoste012df12b32018-11-28 16:57:20 +00001113bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1114 const TensorInfo& input1,
1115 const TensorInfo& output,
1116 Optional<std::string&> reasonIfUnsupported) const
1117{
Sadik Armagan2999a022019-04-09 14:20:12 +01001118 bool supported = true;
1119
Keith Davis5204aa82020-01-27 15:24:59 +00001120 std::array<DataType,5> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001121 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001122 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001123 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001124 DataType::QAsymmU8,
1125 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001126 };
1127
1128 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1129 "Reference maximum: input 0 is not a supported type.");
1130
1131 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1132 "Reference maximum: input 1 is not a supported type.");
1133
1134 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1135 "Reference maximum: output is not a supported type.");
1136
1137 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1138 "Reference maximum: input 0 and Input 1 types are mismatched");
1139
1140 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1141 "Reference maximum: input and output types are mismatched");
1142
1143 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1144 "Reference maximum: shapes are not suitable for implicit broadcast.");
1145
1146 return supported;
saoste012df12b32018-11-28 16:57:20 +00001147}
1148
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001149bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1150 const TensorInfo& output,
1151 const MeanDescriptor& descriptor,
1152 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001153{
James Conroy4d1ff582019-06-10 17:06:39 +01001154 bool supported = true;
1155 std::string meanLayerStr = "Mean";
1156 std::string outputTensorStr = "output";
1157
Matthew Jackson252df3a2019-09-11 09:19:18 +01001158 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001159 {
1160 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001161 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001162 DataType::QAsymmU8,
1163 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001164 };
1165
1166 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1167 "Reference Mean: input type not supported.");
1168
1169 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1170 "Reference Mean: input and output types are mismatched");
1171
1172 if (descriptor.m_KeepDims)
1173 {
1174 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1175 reasonIfUnsupported,
1176 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1177 output.GetNumDimensions(),
1178 meanLayerStr, outputTensorStr).data());
1179 }
1180 else if (descriptor.m_Axis.empty())
1181 {
1182 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1183 reasonIfUnsupported,
1184 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1185 meanLayerStr, outputTensorStr).data());
1186 }
1187 else
1188 {
1189 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1190
1191 if (outputDim > 0)
1192 {
1193 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1194 reasonIfUnsupported,
1195 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1196 meanLayerStr, outputTensorStr).data());
1197 }
1198 else
1199 {
1200 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1201 reasonIfUnsupported,
1202 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1203 meanLayerStr, outputTensorStr).data());
1204 }
1205 }
1206
1207 return supported;
narpra0132b90462018-09-13 11:07:48 +01001208}
1209
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001210bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001211 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001212 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001213 Optional<std::string&> reasonIfUnsupported) const
1214{
Jim Flynne242f2d2019-05-22 14:24:13 +01001215 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001216}
1217
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001218bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1219 const TensorInfo &output,
1220 Optional<std::string &> reasonIfUnsupported) const
1221{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001222 bool supported = true;
1223
1224 std::array<DataType,5> supportedTypes =
1225 {
1226 DataType::Float32,
1227 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001228 DataType::QAsymmU8,
1229 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001230 DataType::Boolean
1231 };
1232
1233 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1234 "Reference MemCopy: input type not supported");
1235
1236 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1237 "Reference MemCopy: output type not supported");
1238
1239 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1240 "Reference MemCopy: input and output types are mismatched");
1241
1242 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001243}
1244
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001245bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1246 const TensorInfo& input1,
1247 const TensorInfo& output,
1248 Optional<std::string&> reasonIfUnsupported) const
1249{
Sadik Armagan2999a022019-04-09 14:20:12 +01001250 bool supported = true;
1251
Matthew Jackson9bff1442019-09-12 09:08:23 +01001252 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001253 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001254 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001255 DataType::QAsymmU8,
1256 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001257 };
1258
1259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1260 "Reference minimum: input 0 is not a supported type.");
1261
1262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1263 "Reference minimum: input 1 is not a supported type.");
1264
1265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1266 "Reference minimum: output is not a supported type.");
1267
1268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1269 "Reference minimum: input 0 and Input 1 types are mismatched");
1270
1271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1272 "Reference minimum: input and output types are mismatched");
1273
1274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1275 "Reference minimum: shapes are not suitable for implicit broadcast.");
1276
1277 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001278}
1279
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001280bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1281 const TensorInfo& input1,
1282 const TensorInfo& output,
1283 Optional<std::string&> reasonIfUnsupported) const
1284{
Sadik Armagan2999a022019-04-09 14:20:12 +01001285 bool supported = true;
1286
Keith Davis67e6c542020-02-19 10:08:33 +00001287 std::array<DataType,6> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001288 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001289 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001290 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001291 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001292 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001293 };
1294
1295 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1296 "Reference multiplication: input 0 is not a supported type.");
1297
1298 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1299 "Reference multiplication: input 1 is not a supported type.");
1300
1301 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1302 "Reference multiplication: output is not a supported type.");
1303
1304 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1305 "Reference multiplication: input 0 and Input 1 types are mismatched");
1306
1307 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1308 "Reference multiplication: input and output types are mismatched");
1309
1310 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1311 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1312
1313 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001314}
1315
1316bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1317 const TensorInfo& output,
1318 const NormalizationDescriptor& descriptor,
1319 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001320{
Nina Drozd661dfa72018-10-02 11:14:17 +01001321 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001322
1323 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001324 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001325 {
1326 DataType::Float16,
1327 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001328 DataType::QAsymmU8,
1329 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001330 };
1331
1332 bool supported = true;
1333
1334 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1335 "Reference normalization: input type not supported.");
1336
1337 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1338 "Reference normalization: output type not supported.");
1339
1340 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1341 "Reference normalization: input and output shapes have different "
1342 "num total elements.");
1343
1344 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001345}
1346
Derek Lamberti901ea112019-12-10 22:07:09 +00001347bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1348 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001349{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001350 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001351}
1352
1353bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1354 const TensorInfo& output,
1355 const PadDescriptor& descriptor,
1356 Optional<std::string&> reasonIfUnsupported) const
1357{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001358 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001359 bool supported = true;
1360
1361 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001362 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001363 {
1364 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001365 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001366 DataType::QAsymmU8,
1367 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001368 };
1369
1370 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1371 "Reference pad: input is not a supported type.");
1372
1373 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1374 "Reference pad: output is not a supported type.");
1375
1376 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1377 "Reference pad: input and output types are mismatched.");
1378
1379 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001380}
1381
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001382bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1383 const TensorInfo& output,
1384 const PermuteDescriptor& descriptor,
1385 Optional<std::string&> reasonIfUnsupported) const
1386{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001387 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001388 bool supported = true;
1389
1390 // Define supported output and inputs types.
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001391 std::array<DataType, 4> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001392 {
1393 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001394 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001395 DataType::QAsymmU8,
1396 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001397 };
1398
1399 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1400 "Reference permute: input is not a supported type.");
1401
1402 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1403 "Reference permute: output is not a supported type.");
1404
1405 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1406 "Reference permute: input and output types are mismatched.");
1407
1408 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001409}
1410
1411bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1412 const TensorInfo& output,
1413 const Pooling2dDescriptor& descriptor,
1414 Optional<std::string&> reasonIfUnsupported) const
1415{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001416 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001417 bool supported = true;
1418
1419 // Define supported output and inputs types.
Keith Davis67e6c542020-02-19 10:08:33 +00001420 std::array<DataType,5> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001421 {
1422 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001423 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001424 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001425 DataType::QAsymmU8,
1426 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001427 };
1428
1429 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1430 "Reference poolind2d: input is not a supported type.");
1431
1432 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1433 "Reference poolind2d: output is not a supported type.");
1434
1435 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1436 "Reference poolind2d: input and output types are mismatched.");
1437
1438 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001439}
1440
Derek Lamberti5f400d62019-03-25 15:41:58 +00001441bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1442 const TensorInfo& output,
1443 Optional<std::string&> reasonIfUnsupported) const
1444{
1445 bool supported = true;
1446
Finn Williamsfd271062019-12-04 14:27:27 +00001447 // Define supported input types.
Ryan OShea9add1202020-02-07 10:06:33 +00001448 std::array<DataType,6> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00001449 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001450 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001451 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001452 DataType::QAsymmU8,
1453 DataType::QSymmS8,
1454 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001455 };
1456
1457 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1458 "Reference quantize: input type not supported.");
1459
1460 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001461 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001462 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001463 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001464 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001465 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001466 };
1467 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1468 "Reference quantize: output type not supported.");
1469
1470 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1471 "Reference quantize: input and output shapes have different num total elements.");
1472
1473 return supported;
1474}
1475
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001476bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001477 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001478 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001479 Optional<std::string&> reasonIfUnsupported) const
1480{
Kevin Maya023c402019-12-12 17:28:05 +00001481 ignore_unused(output);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001482 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001483 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001484 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001485 {
1486 DataType::Float32,
1487 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001488 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001489 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001490 DataType::QAsymmU8,
1491 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001492 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001493
Nina Drozd2f2778f2019-05-27 10:37:05 +01001494 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1495 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001496}
1497
1498bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001499 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001500 Optional<std::string&> reasonIfUnsupported) const
1501{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001502 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001503 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001504 {
1505 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001506 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001507 DataType::QAsymmU8,
1508 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001509 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001510
1511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512 "Reference ResizeBilinear: input type not supported");
1513
1514 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515 "Reference ResizeBilinear: output type not supported");
1516
1517 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1518 "Reference ResizeBilinear: input and output types not matching");
1519
1520 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001521}
1522
Teresa Charlin970f43b2019-07-01 13:51:07 +01001523bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1524 const TensorInfo& output,
1525 const ResizeDescriptor& descriptor,
1526 Optional<std::string&> reasonIfUnsupported) const
1527{
Derek Lamberti901ea112019-12-10 22:07:09 +00001528 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001529 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +00001530 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001531 {
1532 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001533 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001534 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001535 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001536 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001537 };
1538
1539 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1540 "Reference Resize: input type not supported");
1541
1542 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1543 "Reference Resize: output type not supported");
1544
1545 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1546 "Reference Resize: input and output types not matching");
1547
1548 return supported;
1549}
1550
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001551bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1552 const TensorInfo& output,
1553 Optional<std::string&> reasonIfUnsupported) const
1554{
josh minor4a3c6102020-01-06 16:40:46 -06001555 return IsElementwiseUnarySupported(input,
1556 output,
1557 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1558 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001559}
1560
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001561bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1562 const TensorInfo& output,
1563 const SliceDescriptor& descriptor,
1564 Optional<std::string&> reasonIfUnsupported) const
1565{
Derek Lamberti901ea112019-12-10 22:07:09 +00001566 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001567 bool supported = true;
1568
1569 std::array<DataType, 3> supportedTypes =
1570 {
1571 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001572 DataType::QAsymmU8,
1573 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001574 };
1575
1576 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1577 "Reference Slice: input type not supported");
1578
1579 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1580 "Reference Slice: output type not supported");
1581
1582 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1583 "Reference Slice: input and output types are mismatched");
1584
1585 return supported;
1586}
1587
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001588bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1589 const TensorInfo& output,
1590 const SoftmaxDescriptor& descriptor,
1591 Optional<std::string&> reasonIfUnsupported) const
1592{
Derek Lamberti901ea112019-12-10 22:07:09 +00001593 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001594 bool supported = true;
Keith Davis0c2eeac2020-02-11 16:51:50 +00001595 std::array<DataType,6> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001596 {
1597 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001598 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001599 DataType::QSymmS8,
1600 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001601 DataType::QAsymmU8,
1602 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001603 };
1604
1605 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001606 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001607
1608 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001609 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001610
1611 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001612 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001613
1614 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001615}
1616
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001617bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1618 const TensorInfo& output,
1619 const SpaceToBatchNdDescriptor& descriptor,
1620 Optional<std::string&> reasonIfUnsupported) const
1621{
Derek Lamberti901ea112019-12-10 22:07:09 +00001622 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001623 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001624 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001625 {
1626 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001627 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001628 DataType::QAsymmU8,
1629 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001630 };
1631
1632 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1633 "Reference SpaceToBatchNd: input type not supported");
1634
1635 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1636 "Reference SpaceToBatchNd: output type not supported");
1637
1638 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1639 "Reference SpaceToBatchNd: input and output types are mismatched");
1640
1641 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001642}
1643
Keith Davisa57eccb2019-06-14 17:33:22 +01001644bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001645 const TensorInfo& output,
1646 const SpaceToDepthDescriptor& descriptor,
1647 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001648{
1649
1650 ignore_unused(descriptor);
1651 bool supported = true;
1652
Matthew Jackson9bff1442019-09-12 09:08:23 +01001653 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001654 {
1655 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001656 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001657 DataType::QAsymmU8,
1658 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001659 };
1660
1661 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1662 "Reference SpaceToDepth: input type not supported");
1663
1664 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1665 "Reference SpaceToDepth: output type not supported");
1666
1667 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1668 "Reference SpaceToDepth: input and output types are mismatched");
1669
1670 return supported;
1671}
1672
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001673bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1674 const ViewsDescriptor& descriptor,
1675 Optional<std::string&> reasonIfUnsupported) const
1676{
1677 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001678 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001679 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001680 {
1681 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001682 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001683 DataType::QAsymmU8,
1684 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001685 };
1686
1687 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1688 "Reference splitter: input type not supported");
1689
1690 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001691}
1692
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001693bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1694 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1695 const ViewsDescriptor& descriptor,
1696 Optional<std::string&> reasonIfUnsupported) const
1697{
1698 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001699 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001700 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001701 {
1702 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001703 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001704 DataType::QAsymmU8,
1705 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001706 };
1707
1708 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1709 "Reference splitter: output type not supported");
1710 for (const TensorInfo output : outputs)
1711 {
1712 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1713 "Reference splitter: input type not supported");
1714
1715 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1716 "Reference splitter: input and output types mismatched.");
1717 }
1718
1719 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001720}
1721
Matthew Jackson81e601c2019-07-11 12:07:09 +01001722bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1723 const TensorInfo& output,
1724 const StackDescriptor& descriptor,
1725 Optional<std::string&> reasonIfUnsupported) const
1726{
1727 ignore_unused(descriptor);
1728
1729 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001730 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001731 {
1732 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001733 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001734 DataType::QAsymmU8,
1735 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001736 };
1737
1738 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1739 "Reference stack: output type not supported");
1740 for (const TensorInfo* input : inputs)
1741 {
1742 BOOST_ASSERT(input != nullptr);
1743 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1744 "Reference stack: input type not supported");
1745
1746 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1747 "Reference stack: input and output types mismatched.");
1748 }
1749
1750 return supported;
1751}
1752
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001753bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1754 const TensorInfo& output,
1755 const StridedSliceDescriptor& descriptor,
1756 Optional<std::string&> reasonIfUnsupported) const
1757{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001758 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001759 bool supported = true;
1760
1761 std::array<DataType,3> supportedTypes =
1762 {
1763 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001764 DataType::QAsymmU8,
1765 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001766 };
1767
1768 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1769 "Reference StridedSlice: input type not supported");
1770
1771 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1772 "Reference StridedSlice: output type not supported");
1773
1774 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1775 "Reference StridedSlice: input and output types are mismatched");
1776
1777 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001778}
1779
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001780bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1781 const TensorInfo& input1,
1782 const TensorInfo& output,
1783 Optional<std::string&> reasonIfUnsupported) const
1784{
Sadik Armagan2999a022019-04-09 14:20:12 +01001785 bool supported = true;
1786
Matthew Jackson9bff1442019-09-12 09:08:23 +01001787 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001788 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001789 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001790 DataType::QAsymmU8,
1791 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001792 };
1793
1794 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1795 "Reference subtraction: input 0 is not a supported type.");
1796
1797 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1798 "Reference subtraction: input 1 is not a supported type.");
1799
1800 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1801 "Reference subtraction: output is not a supported type.");
1802
1803 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1804 "Reference subtraction: input 0 and Input 1 types are mismatched");
1805
1806 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1807 "Reference subtraction: input and output types are mismatched");
1808
1809 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1810 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1811
1812 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001813}
1814
Matteo Martincighab9e5252019-06-13 17:27:46 +01001815bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1816 const TensorInfo& alpha,
1817 const TensorInfo& output,
1818 Optional<std::string&> reasonIfUnsupported) const
1819{
1820 bool supported = true;
1821
Matthew Jackson9bff1442019-09-12 09:08:23 +01001822 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001823 {
1824 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001825 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001826 DataType::QAsymmU8,
1827 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001828 };
1829
1830 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1831 "PReLU: input is not a supported type.");
1832
1833 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1834 "PReLU: alpha is not a supported type.");
1835
1836 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1837 "PReLU: output is not a supported type.");
1838
1839 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1840 "PReLU: input, alpha and output types are mismatched");
1841
1842 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1843 "PReLU: shapes are not suitable for implicit broadcast");
1844
1845 return supported;
1846}
1847
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001848bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1849 const TensorInfo& output,
1850 const TransposeConvolution2dDescriptor& descriptor,
1851 const TensorInfo& weights,
1852 const Optional<TensorInfo>& biases,
1853 Optional<std::string&> reasonIfUnsupported) const
1854{
Derek Lamberti901ea112019-12-10 22:07:09 +00001855 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001856 bool supported = true;
1857
Matthew Jackson252df3a2019-09-11 09:19:18 +01001858 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001859 {
1860 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001861 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001862 DataType::QAsymmU8,
1863 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001864 };
1865
1866 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1867 "Reference TransposeConvolution2d: input is not a supported type.");
1868
1869 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1870 "Reference TransposeConvolution2d: output is not a supported type.");
1871
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001872 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1873 "Reference TransposeConvolution2d: input and output types mismatched.");
1874
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001875
1876 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001877 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001878 {
Derek Lambertid466a542020-01-22 15:37:29 +00001879 ARMNN_NO_DEPRECATE_WARN_BEGIN
1880 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001881 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001882 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001883 DataType::QSymmS8,
1884 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001885 };
Derek Lambertid466a542020-01-22 15:37:29 +00001886 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001887
1888 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1889 "Reference TransposeConvolution2d: weights type not supported for "
1890 "quantized input.");
1891 }
1892 else
1893 {
1894 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1895 "Reference TransposeConvolution2d: weights is not a supported type.");
1896
1897 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1898 "Reference TransposeConvolution2d: input and weights types mismatched.");
1899 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001900
1901 if (biases.has_value())
1902 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001903 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001904 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001905 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001906 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001907 DataType::Signed32
1908 };
1909 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1910 "Reference TransposeConvolution2d: biases is not a supported type.");
1911 }
1912
1913 return supported;
1914}
1915
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001916bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
1917 const TensorInfo& output,
1918 const TransposeDescriptor& descriptor,
1919 Optional<std::string&> reasonIfUnsupported) const
1920{
1921 ignore_unused(descriptor);
1922 bool supported = true;
1923
1924 // Define supported output and inputs types.
1925 std::array<DataType, 4> supportedTypes =
1926 {
1927 DataType::Float32,
1928 DataType::Float16,
1929 DataType::QAsymmU8,
1930 DataType::QSymmS16
1931 };
1932
1933 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1934 "Reference transpose: input is not a supported type.");
1935
1936 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1937 "Reference transpose: output is not a supported type.");
1938
1939 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1940 "Reference transpose: input and output types are mismatched.");
1941
1942 return supported;
1943}
1944
arovir011c7c81b2018-10-08 11:34:28 +01001945} // namespace armnn