blob: bba83e23d4edacda54d96ea038265d37673f5974 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matteo Martincighe011d202019-11-28 11:35:47 +000012#include <LayerSupportCommon.hpp>
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000015
Matteo Martincighe011d202019-11-28 11:35:47 +000016#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000020#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
30template<typename Float32Func, typename Uint8Func, typename ... Params>
31bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
32 DataType dataType,
33 Float32Func floatFuncPtr,
34 Uint8Func uint8FuncPtr,
35 Params&&... params)
36{
37 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
38 dataType,
39 &FalseFunc<Params...>,
40 floatFuncPtr,
41 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000042 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000043 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 std::forward<Params>(params)...);
45}
46
47} // anonymous namespace
48
namespace
{

/// Builds the "wrong number of dimensions" failure message used by the reference support checks.
/// @param expected    Number of dimensions the layer requires.
/// @param actual      Number of dimensions the tensor actually has.
/// @param layerStr    Layer name, inserted after "Reference ".
/// @param tensorName  Name of the offending tensor, quoted in the message.
/// @return A message of the form
///         "Reference <layer>: Expected <expected> dimensions but got <actual> dimensions instead,
///          for the '<tensor>' tensor."
// The strings are only read, so they are taken by const reference: callers may pass const
// strings or temporaries, and existing callers passing non-const lvalues still bind.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000064
Sadik Armagan9199e582019-09-05 17:35:31 +010065bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
66 Optional<std::string&> reasonIfUnsupported) const
67{
josh minor4a3c6102020-01-06 16:40:46 -060068 return IsElementwiseUnarySupported(input,
69 output,
70 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
71 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010072}
73
arovir011c7c81b2018-10-08 11:34:28 +010074bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
75 const TensorInfo& output,
76 const ActivationDescriptor& descriptor,
77 Optional<std::string&> reasonIfUnsupported) const
78{
Derek Lamberti50db4e82019-03-13 14:16:15 +000079 bool supported = true;
80
81 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000082 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +000083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010084 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000085 DataType::QSymmS8,
86 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000087 DataType::QAsymmU8,
88 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000089 };
90
91 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
92 "Reference activation: input type not supported.");
93
94 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
95 "Reference activation: output type not supported.");
96
97 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
98 "Reference activation: input and output types mismatched.");
99
100 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
101 "Reference activation: input and output shapes are of different rank.");
102
103
104 struct ActivationFunctionSupported : public Rule
105 {
106 ActivationFunctionSupported(const ActivationDescriptor& desc)
107 {
108 switch(desc.m_Function)
109 {
110 case ActivationFunction::Abs:
111 case ActivationFunction::BoundedReLu:
112 case ActivationFunction::LeakyReLu:
113 case ActivationFunction::Linear:
114 case ActivationFunction::ReLu:
115 case ActivationFunction::Sigmoid:
116 case ActivationFunction::SoftReLu:
117 case ActivationFunction::Sqrt:
118 case ActivationFunction::Square:
119 case ActivationFunction::TanH:
120 {
121 m_Res = true;
122 break;
123 }
124 default:
125 {
126 m_Res = false;
127 break;
128 }
129 }
130 }
131 };
132
133 // Function is supported
134 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
135 "Reference activation: function not supported.");
136
137 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100138}
139
140bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
141 const TensorInfo& input1,
142 const TensorInfo& output,
143 Optional<std::string&> reasonIfUnsupported) const
144{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000145 bool supported = true;
146
Keith Davis0c2eeac2020-02-11 16:51:50 +0000147 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000148 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100149 DataType::Float16,
Keith Davis5204aa82020-01-27 15:24:59 +0000150 DataType::QSymmS8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000151 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000152 DataType::QAsymmU8,
153 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000154 };
155
156 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
157 "Reference addition: input 0 is not a supported type.");
158
159 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
160 "Reference addition: input 1 is not a supported type.");
161
162 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
163 "Reference addition: output is not a supported type.");
164
165 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
166 "Reference addition: input 0 and Input 1 types are mismatched");
167
168 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
169 "Reference addition: input and output types are mismatched");
170
171 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
172 "Reference addition: shapes are not suitable for implicit broadcast.");
173
174 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100175}
176
Nikhil Raj68c2c902019-09-19 11:21:11 +0100177bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
178 const armnn::ArgMinMaxDescriptor &descriptor,
179 armnn::Optional<std::string &> reasonIfUnsupported) const
180{
181 ignore_unused(descriptor);
182
Francis Murtagh1939df52019-11-13 15:21:09 +0000183 std::array<DataType, 4> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100184 {
185 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000186 DataType::QAsymmU8,
187 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000188 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100189 };
190
191 bool supported = true;
192
193 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
194 "Reference ArgMinMax: input is not a supported type.");
195 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
196 "Reference ArgMinMax: output type not supported");
197
198 return supported;
199}
200
arovir011c7c81b2018-10-08 11:34:28 +0100201bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
202 const TensorInfo& output,
203 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100204 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100205 const TensorInfo& beta,
206 const TensorInfo& gamma,
207 const BatchNormalizationDescriptor& descriptor,
208 Optional<std::string&> reasonIfUnsupported) const
209{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100210 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100211
Matthew Jackson9bff1442019-09-12 09:08:23 +0100212 std::array<DataType, 4> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100213 {
214 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100215 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000216 DataType::QAsymmU8,
217 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100218 };
219
220 bool supported = true;
221
222 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
223 "Reference batch normalization: input is not a supported type.");
224
225 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
226 "Reference batch normalization: output is not a supported type.");
227
228 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
229 "Reference batch normalization: input and output types are mismatched");
230
231 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
232 "Reference batch normalization: mean is not a supported type.");
233
234 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
235 "Reference batch normalization: variance is not a supported type.");
236
237 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
238 "Reference batch normalization: beta is not a supported type.");
239
240 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
241 "Reference batch normalization: gamma is not a supported type.");
242
243 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100244}
245
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000246bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
247 const TensorInfo& output,
248 const BatchToSpaceNdDescriptor& descriptor,
249 Optional<std::string&> reasonIfUnsupported) const
250{
251 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100252
253 bool supported = true;
254
255 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
256 std::string inputTensorStr = "input";
257 std::string outputTensorStr = "output";
258
259 // Define supported types.
Matthew Jackson9bff1442019-09-12 09:08:23 +0100260 std::array<DataType,4> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100261 {
262 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100263 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000264 DataType::QAsymmU8,
265 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100266 };
267
268 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
269 "Reference BatchToSpaceNd: input type not supported.");
270
271 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
272 "Reference BatchToSpaceNd: output type not supported.");
273
274 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
275 "Reference BatchToSpaceNd: input and output types mismatched.");
276
277 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
278 reasonIfUnsupported,
279 CreateIncorrectDimensionsErrorMsg(4,
280 output.GetNumDimensions(),
281 batchToSpaceNdLayerStr,
282 outputTensorStr).data());
283
284 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
285 reasonIfUnsupported,
286 CreateIncorrectDimensionsErrorMsg(4,
287 input.GetNumDimensions(),
288 batchToSpaceNdLayerStr,
289 inputTensorStr).data());
290
291 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000292}
293
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100294bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
295 const TensorInfo& input1,
296 const TensorInfo& output,
297 const ComparisonDescriptor& descriptor,
298 Optional<std::string&> reasonIfUnsupported) const
299{
300 boost::ignore_unused(descriptor);
301
302 std::array<DataType, 4> supportedInputTypes =
303 {
304 DataType::Float32,
305 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000306 DataType::QAsymmU8,
307 DataType::QSymmS16
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100308 };
309
310 bool supported = true;
311 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
312 "Reference comparison: input 0 is not a supported type");
313
314 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
315 "Reference comparison: input 0 and Input 1 types are mismatched");
316
317 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
318 "Reference comparison: output is not of type Boolean");
319
320 return supported;
321}
322
Jim Flynn906f9462019-05-10 13:55:21 +0100323bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
324 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100325 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100326 Optional<std::string&> reasonIfUnsupported) const
327{
Jim Flynne242f2d2019-05-22 14:24:13 +0100328 ignore_unused(descriptor);
329
330 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +0000331 std::array<DataType,5> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100332 {
333 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100334 DataType::Float16,
Keith Davis5204aa82020-01-27 15:24:59 +0000335 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000336 DataType::QAsymmU8,
337 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100338 };
339
340 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
341 "Reference concatenation: output type not supported");
342 for (const TensorInfo* input : inputs)
343 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100344 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100345 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
346 "Reference concatenation: input type not supported");
347
348 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
349 "Reference concatenation: input and output types mismatched.");
350 }
351
352 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100353}
354
arovir011c7c81b2018-10-08 11:34:28 +0100355bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
356 Optional<std::string&> reasonIfUnsupported) const
357{
Keith Davis5204aa82020-01-27 15:24:59 +0000358 std::array<DataType,5> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100359 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100360 DataType::Float32,
361 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000362 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000363 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000364 DataType::QSymmS16
Nina Drozd58ef2c62019-05-16 12:09:18 +0100365 };
366
367 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
368 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100369}
370
371bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
372 const TensorInfo& output,
373 Optional<std::string&> reasonIfUnsupported) const
374{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100375 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
376 input.GetDataType(),
377 &TrueFunc<>,
378 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000379 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000380 &FalseFuncI32<>,
381 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100382 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
383 output.GetDataType(),
384 &FalseOutputFuncF16<>,
385 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000386 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000387 &FalseFuncI32<>,
388 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100389}
390
391bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
392 const TensorInfo& output,
393 Optional<std::string&> reasonIfUnsupported) const
394{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100395 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
396 input.GetDataType(),
397 &FalseInputFuncF16<>,
398 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000399 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000400 &FalseFuncI32<>,
401 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100402 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
403 output.GetDataType(),
404 &TrueFunc<>,
405 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000406 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000407 &FalseFuncI32<>,
408 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100409}
410
411bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
412 const TensorInfo& output,
413 const Convolution2dDescriptor& descriptor,
414 const TensorInfo& weights,
415 const Optional<TensorInfo>& biases,
416 Optional<std::string&> reasonIfUnsupported) const
417{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100418 bool supported = true;
419
420 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000421 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000422 {
423 DataType::Float32,
424 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000425 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000426 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000427 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000428 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100429 };
430
431 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000432 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100433
434 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000435 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100436
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100437 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000438 "Reference Convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100439
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000440 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000441 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000442 {
Derek Lambertid466a542020-01-22 15:37:29 +0000443 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000444 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000445 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000446 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000447 DataType::QSymmS8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000448 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000449 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000450 };
Derek Lambertid466a542020-01-22 15:37:29 +0000451 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000452
453 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000454 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000455 }
456 else
457 {
458 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000459 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000460
461 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000462 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000463 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100464
465 if (biases.has_value())
466 {
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000467 std::array<DataType,3> biasesSupportedTypes =
468 {
469 DataType::Float32,
470 DataType::Float16,
471 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100472 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000473
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100474 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000475 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100476 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100477 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100478
479 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100480}
481
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000482bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
483 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000484 Optional<std::string&> reasonIfUnsupported) const
485{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100486 bool supported = true;
487
Keith Davis0c2eeac2020-02-11 16:51:50 +0000488 std::array<DataType, 7> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100489 {
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000490 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100491 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000492 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000493 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000494 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000495 DataType::QSymmS16,
496 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100497 };
498
499 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000500 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100501
502 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000503 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100504
505 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000506 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100507
508 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000509}
510
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100511bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
512 const TensorInfo& output,
513 const DepthToSpaceDescriptor& descriptor,
514 Optional<std::string&> reasonIfUnsupported) const
515{
516 ignore_unused(descriptor);
517 bool supported = true;
518
519 std::array<DataType,4> supportedTypes =
520 {
521 DataType::Float32,
522 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000523 DataType::QAsymmU8,
524 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100525 };
526
527 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
528 "Reference DepthToSpace: input type not supported");
529
530 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
531 "Reference DepthToSpace: output type not supported");
532
533 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
534 "Reference DepthToSpace: input and output types are mismatched");
535
536 return supported;
537}
538
arovir011c7c81b2018-10-08 11:34:28 +0100539bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
540 const TensorInfo& output,
541 const DepthwiseConvolution2dDescriptor& descriptor,
542 const TensorInfo& weights,
543 const Optional<TensorInfo>& biases,
544 Optional<std::string&> reasonIfUnsupported) const
545{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100546 bool supported = true;
547
548 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000549 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100550 {
551 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100552 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000553 DataType::QSymmS8,
554 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000555 DataType::QAsymmU8,
556 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100557 };
558
559 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
560 "Reference DepthwiseConvolution2d: input is not a supported type.");
561
562 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
563 "Reference DepthwiseConvolution2d: output is not a supported type.");
564
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
566 "Reference DepthwiseConvolution2d: input and output types mismatched.");
567
Derek Lambertid466a542020-01-22 15:37:29 +0000568 ARMNN_NO_DEPRECATE_WARN_BEGIN
569 std::array<DataType, 3> supportedWeightTypes =
570 {
571 DataType::QAsymmU8,
572 DataType::QSymmS8,
573 DataType::QuantizedSymm8PerAxis // deprecated
574 };
575 ARMNN_NO_DEPRECATE_WARN_END
576
Teresa Charlind8df0262019-11-11 12:28:15 +0000577 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000578 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000579 {
Teresa Charlind8df0262019-11-11 12:28:15 +0000580
581 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
582 "Reference convolution2d: weights type not supported for quantized input.");
583 }
584 else
585 {
586 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
587 "Reference DepthwiseConvolution2d: weights is not a supported type.");
588
589 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
590 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
591 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100592
593 if (biases.has_value())
594 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100595 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100596 {
597 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100598 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100599 DataType::Signed32
600 };
601 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
602 "Reference DepthwiseConvolution2d: biases is not a supported type.");
603 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100604 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100605
606 return supported;
607
arovir011c7c81b2018-10-08 11:34:28 +0100608}
609
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000610bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
611 const TensorInfo& output,
612 Optional<std::string&> reasonIfUnsupported) const
613{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100614 bool supported = true;
615
Ryan OShea9add1202020-02-07 10:06:33 +0000616 std::array<DataType,4> supportedInputTypes = {
617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000618 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000619 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000620 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100621 };
622
623 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000624 "Reference for Dequantize layer: input type not supported.");
625
626 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
627 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100628
Derek Lambertid466a542020-01-22 15:37:29 +0000629 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
630 "Reference dequantize: per-axis quantized input not support .");
631
Jan Eilersf7107932019-11-01 11:09:36 +0000632 std::array<DataType,2> supportedOutputTypes = {
633 DataType::Float32,
634 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100635 };
636
637 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000638 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100639
640 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000641 "Reference for Dequantize layer: input/output shapes have different num total "
642 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100643
644 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000645}
646
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000647bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
648 const TensorInfo& scores,
649 const TensorInfo& anchors,
650 const TensorInfo& detectionBoxes,
651 const TensorInfo& detectionClasses,
652 const TensorInfo& detectionScores,
653 const TensorInfo& numDetections,
654 const DetectionPostProcessDescriptor& descriptor,
655 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000656{
Derek Lamberti901ea112019-12-10 22:07:09 +0000657 boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
658
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100659 bool supported = true;
660
Mike Kelly4992c342019-08-14 11:33:11 +0100661 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100662 {
663 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000664 DataType::QAsymmU8,
665 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100666 };
667
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000668 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100669 "Reference DetectionPostProcess: input 0 is not a supported type.");
670
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000671 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100672 "Reference DetectionPostProcess: input 1 is not a supported type.");
673
674 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000675}
676
Pablo Tellof0bd6832019-04-26 17:58:13 +0100677bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
678 const TensorInfo& output,
679 const DepthwiseConvolution2dDescriptor& descriptor,
680 const TensorInfo& weights,
681 const Optional<TensorInfo>& biases,
682 Optional<std::string&> reasonIfUnsupported) const
683{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100684 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100685}
686
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100687bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100688 const TensorInfo& input1,
689 const TensorInfo& output,
690 Optional<std::string&> reasonIfUnsupported) const
691{
Sadik Armagan2999a022019-04-09 14:20:12 +0100692 bool supported = true;
693
Matthew Jackson9bff1442019-09-12 09:08:23 +0100694 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100695 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100696 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000697 DataType::QAsymmU8,
698 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100699 };
700
701 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
702 "Reference division: input 0 is not a supported type.");
703
704 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
705 "Reference division: input 1 is not a supported type.");
706
707 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
708 "Reference division: output is not a supported type.");
709
710 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
711 "Reference division: input 0 and Input 1 types are mismatched");
712
713 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
714 "Reference division: input and output types are mismatched");
715
716 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
717 "Reference division: shapes are not suitable for implicit broadcast.");
718
719 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100720}
721
josh minor4a3c6102020-01-06 16:40:46 -0600722bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
723 const TensorInfo& output,
724 const ElementwiseUnaryDescriptor& descriptor,
725 Optional<std::string&> reasonIfUnsupported) const
726{
727 boost::ignore_unused(descriptor);
728
729 std::array<DataType, 4> supportedTypes =
730 {
731 DataType::Float32,
732 DataType::Float16,
733 DataType::QAsymmU8,
734 DataType::QSymmS16
735 };
736
737 bool supported = true;
738
739 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
740 "Reference elementwise unary: input type not supported");
741
742 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
743 "Reference elementwise unary: output type not supported");
744
745 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
746 "Reference elementwise unary: input and output types not matching");
747
748 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
749 "Reference elementwise unary: input and output shapes"
750 "have different number of total elements");
751
752 return supported;
753}
754
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000755bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
756 const TensorInfo& input1,
757 const TensorInfo& output,
758 Optional<std::string&> reasonIfUnsupported) const
759{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100760 return IsComparisonSupported(input0,
761 input1,
762 output,
763 ComparisonDescriptor(ComparisonOperation::Equal),
764 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000765}
766
arovir011c7c81b2018-10-08 11:34:28 +0100767bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
768 const FakeQuantizationDescriptor& descriptor,
769 Optional<std::string&> reasonIfUnsupported) const
770{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100771 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100772 bool supported = true;
773
774 std::array<DataType,1> supportedTypes =
775 {
776 DataType::Float32
777 };
778
779 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
780 "Reference fake quantization: input type not supported.");
781
782 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100783}
784
785bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
786 const TensorInfo& output,
787 Optional<std::string&> reasonIfUnsupported) const
788{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100789 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100790 bool supported = true;
791
Matthew Jackson9bff1442019-09-12 09:08:23 +0100792 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100793 {
James Conroyb40d7102019-06-04 12:32:09 +0100794 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100795 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000796 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100797 };
798
799 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
800 "Reference Floor: input type not supported.");
801
802 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
803 "Reference Floor: output type not supported.");
804
805 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100806}
807
808bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
809 const TensorInfo& output,
810 const TensorInfo& weights,
811 const TensorInfo& biases,
812 const FullyConnectedDescriptor& descriptor,
813 Optional<std::string&> reasonIfUnsupported) const
814{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100815 bool supported = true;
816
817 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100818 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100819 {
820 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100821 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000822 DataType::QAsymmU8,
823 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100824 };
825
826 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
827 "Reference Fully Connected: input type not supported.");
828
829 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
830 "Reference Fully Connected: output type not supported.");
831
832 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
833 "Reference Fully Connected: input and output types mismatched.");
834
835 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
836 "Reference Fully Connected: weights type not supported.");
837
838 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
839 "Reference Fully Connected: input and weight types mismatched.");
840
841 if (descriptor.m_BiasEnabled)
842 {
843 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100844 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100845 supportedBiasTypes =
846 {
847 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100848 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100849 DataType::Signed32
850 };
851
852 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
853 "Reference Fully Connected: bias type not supported.");
854
855 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
856 "Reference Fully Connected: bias and weight types mismatch.");
857
858 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
859 "Reference Fully Connected: bias type inferred from weights is incompatible.");
860
861 }
862
863 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100864}
865
narpra014951d842019-01-18 16:53:53 +0000866bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
867 const armnn::TensorInfo& input1,
868 const armnn::TensorInfo& output,
869 armnn::Optional<std::string&> reasonIfUnsupported) const
870{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100871 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100872 std::array<DataType,4> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100873 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100874 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100875 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000876 DataType::QAsymmU8,
877 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100878 };
879
880 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
881 "Reference Gather: input type not supported");
882
883 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
884 "Reference Gather: output type not supported");
885
886 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
887 "Reference Gather: indices (input1) type not supported");
888
889 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
890 "Reference Gather: input and output types not matching");
891
892 return supported;
narpra014951d842019-01-18 16:53:53 +0000893}
894
FrancisMurtagh878f0232018-12-19 10:56:15 +0000895bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
896 const TensorInfo& input1,
897 const TensorInfo& output,
898 Optional<std::string&> reasonIfUnsupported) const
899{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100900 return IsComparisonSupported(input0,
901 input1,
902 output,
903 ComparisonDescriptor(ComparisonOperation::Greater),
904 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +0000905}
906
Derek Lamberti901ea112019-12-10 22:07:09 +0000907bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
908 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +0100909{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100910 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100911}
912
Kevin May09ca49c2019-10-09 12:37:34 +0100913bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
914 const TensorInfo& output,
915 const InstanceNormalizationDescriptor& descriptor,
916 Optional<std::string&> reasonIfUnsupported) const
917{
918 ignore_unused(descriptor);
919 // Define supported types
920 std::array<DataType, 4> supportedTypes =
921 {
922 DataType::Float32,
923 DataType::Float16
924 };
925
926 bool supported = true;
927
928 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
929 "Reference Instance Normalization: input type not supported.");
930
931 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
932 "Reference Instance Normalization: output type not supported.");
933
934 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
935 "Reference Instance Normalization: input and output types mismatched.");
936
937 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
938 "Reference Instance Normalization: input and output shapes have different "
939 "num total elements.");
940
941 return supported;
942}
943
arovir011c7c81b2018-10-08 11:34:28 +0100944bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
945 const TensorInfo& output,
946 const L2NormalizationDescriptor& descriptor,
947 Optional<std::string&> reasonIfUnsupported) const
948{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100949 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100950 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100951 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100952 {
953 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100954 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000955 DataType::QAsymmU8,
956 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100957 };
958
959 bool supported = true;
960
961 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
962 "Reference L2normalization: input type not supported.");
963
964 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
965 "Reference L2normalization: output type not supported.");
966
967 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
968 "Reference L2normalization: input and output types mismatched.");
969
970 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
971 "Reference L2normalization: input and output shapes have different "
972 "num total elements.");
973
974 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100975}
976
Aron Virginas-Tare662a942019-10-14 15:12:00 +0100977bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
978 const TensorInfo& output,
979 const LogSoftmaxDescriptor& descriptor,
980 Optional<std::string&> reasonIfUnsupported) const
981{
982 ignore_unused(descriptor);
983
984 std::array<DataType, 2> supportedTypes =
985 {
986 DataType::Float32,
987 DataType::Float16
988 };
989
990 bool supported = true;
991 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
992 "Reference LogSoftmax: input type not supported");
993
994 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
995 "Reference LogSoftmax: output type not supported");
996
997 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
998 "Reference LogSoftmax: input and output types do not match");
999
1000 return supported;
1001}
1002
arovir011c7c81b2018-10-08 11:34:28 +01001003bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1004 const TensorInfo& outputStateIn,
1005 const TensorInfo& cellStateIn,
1006 const TensorInfo& scratchBuffer,
1007 const TensorInfo& outputStateOut,
1008 const TensorInfo& cellStateOut,
1009 const TensorInfo& output,
1010 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001011 const LstmInputParamsInfo& paramsInfo,
1012 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001013{
telsoa01c577f2c2018-08-31 09:22:23 +01001014 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +01001015 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001016
1017 bool supported = true;
1018
1019 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001020 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001021 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001022 };
1023
Jan Eilersd01a83c2019-07-03 18:20:40 +01001024 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001025 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1026 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001027 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1028 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001029 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1030 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001031 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1032 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001033 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1034 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001035 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1036 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001037 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1038 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001039 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001040 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001041 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001042 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001043 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001044 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001045 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001046 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001047 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001048 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001049 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001050 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001051 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001052 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001053 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001054 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001055 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001056 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001057 "Reference Lstm: input and OutputGateBias types are mismatched");
1058 if (!descriptor.m_CifgEnabled)
1059 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001060 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001061 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001062 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001063 reasonIfUnsupported,
1064 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001065 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001066 "Reference Lstm: input and InputGateBias types are mismatched");
1067 if (descriptor.m_PeepholeEnabled)
1068 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001069 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001070 reasonIfUnsupported,
1071 "Reference Lstm: input and CellToInputWeights types are mismatched");
1072 }
1073 }
1074 if (descriptor.m_PeepholeEnabled)
1075 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001076 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001077 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001078 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001079 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1080 }
1081 if (descriptor.m_ProjectionEnabled)
1082 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001083 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001084 "Reference Lstm: input and mProjectionWeights types are mismatched");
1085 if (paramsInfo.m_ProjectionBias != nullptr)
1086 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001087 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001088 "Reference Lstm: input and ProjectionBias types are mismatched");
1089 }
1090 }
1091 if (descriptor.m_LayerNormEnabled)
1092 {
1093 if (!descriptor.m_CifgEnabled)
1094 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001095 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001096 reasonIfUnsupported,
1097 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1098 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001099 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001100 reasonIfUnsupported,
1101 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001102 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001103 reasonIfUnsupported,
1104 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001105 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001106 reasonIfUnsupported,
1107 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1108 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001109
1110 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001111}
1112
saoste012df12b32018-11-28 16:57:20 +00001113bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1114 const TensorInfo& input1,
1115 const TensorInfo& output,
1116 Optional<std::string&> reasonIfUnsupported) const
1117{
Sadik Armagan2999a022019-04-09 14:20:12 +01001118 bool supported = true;
1119
Keith Davis5204aa82020-01-27 15:24:59 +00001120 std::array<DataType,5> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001121 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001122 DataType::Float16,
Keith Davis5204aa82020-01-27 15:24:59 +00001123 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001124 DataType::QAsymmU8,
1125 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001126 };
1127
1128 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1129 "Reference maximum: input 0 is not a supported type.");
1130
1131 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1132 "Reference maximum: input 1 is not a supported type.");
1133
1134 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1135 "Reference maximum: output is not a supported type.");
1136
1137 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1138 "Reference maximum: input 0 and Input 1 types are mismatched");
1139
1140 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1141 "Reference maximum: input and output types are mismatched");
1142
1143 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1144 "Reference maximum: shapes are not suitable for implicit broadcast.");
1145
1146 return supported;
saoste012df12b32018-11-28 16:57:20 +00001147}
1148
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001149bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1150 const TensorInfo& output,
1151 const MeanDescriptor& descriptor,
1152 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001153{
James Conroy4d1ff582019-06-10 17:06:39 +01001154 bool supported = true;
1155 std::string meanLayerStr = "Mean";
1156 std::string outputTensorStr = "output";
1157
Matthew Jackson252df3a2019-09-11 09:19:18 +01001158 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001159 {
1160 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001161 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001162 DataType::QAsymmU8,
1163 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001164 };
1165
1166 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1167 "Reference Mean: input type not supported.");
1168
1169 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1170 "Reference Mean: input and output types are mismatched");
1171
1172 if (descriptor.m_KeepDims)
1173 {
1174 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1175 reasonIfUnsupported,
1176 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1177 output.GetNumDimensions(),
1178 meanLayerStr, outputTensorStr).data());
1179 }
1180 else if (descriptor.m_Axis.empty())
1181 {
1182 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1183 reasonIfUnsupported,
1184 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1185 meanLayerStr, outputTensorStr).data());
1186 }
1187 else
1188 {
1189 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1190
1191 if (outputDim > 0)
1192 {
1193 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1194 reasonIfUnsupported,
1195 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1196 meanLayerStr, outputTensorStr).data());
1197 }
1198 else
1199 {
1200 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1201 reasonIfUnsupported,
1202 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1203 meanLayerStr, outputTensorStr).data());
1204 }
1205 }
1206
1207 return supported;
narpra0132b90462018-09-13 11:07:48 +01001208}
1209
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001210bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001211 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001212 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001213 Optional<std::string&> reasonIfUnsupported) const
1214{
Jim Flynne242f2d2019-05-22 14:24:13 +01001215 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001216}
1217
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001218bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1219 const TensorInfo &output,
1220 Optional<std::string &> reasonIfUnsupported) const
1221{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001222 bool supported = true;
1223
1224 std::array<DataType,5> supportedTypes =
1225 {
1226 DataType::Float32,
1227 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001228 DataType::QAsymmU8,
1229 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001230 DataType::Boolean
1231 };
1232
1233 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1234 "Reference MemCopy: input type not supported");
1235
1236 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1237 "Reference MemCopy: output type not supported");
1238
1239 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1240 "Reference MemCopy: input and output types are mismatched");
1241
1242 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001243}
1244
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001245bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1246 const TensorInfo& input1,
1247 const TensorInfo& output,
1248 Optional<std::string&> reasonIfUnsupported) const
1249{
Sadik Armagan2999a022019-04-09 14:20:12 +01001250 bool supported = true;
1251
Matthew Jackson9bff1442019-09-12 09:08:23 +01001252 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001253 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001254 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001255 DataType::QAsymmU8,
1256 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001257 };
1258
1259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1260 "Reference minimum: input 0 is not a supported type.");
1261
1262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1263 "Reference minimum: input 1 is not a supported type.");
1264
1265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1266 "Reference minimum: output is not a supported type.");
1267
1268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1269 "Reference minimum: input 0 and Input 1 types are mismatched");
1270
1271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1272 "Reference minimum: input and output types are mismatched");
1273
1274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1275 "Reference minimum: shapes are not suitable for implicit broadcast.");
1276
1277 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001278}
1279
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001280bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1281 const TensorInfo& input1,
1282 const TensorInfo& output,
1283 Optional<std::string&> reasonIfUnsupported) const
1284{
Sadik Armagan2999a022019-04-09 14:20:12 +01001285 bool supported = true;
1286
Keith Davis5204aa82020-01-27 15:24:59 +00001287 std::array<DataType,5> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001288 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001289 DataType::Float16,
Keith Davis5204aa82020-01-27 15:24:59 +00001290 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001291 DataType::QAsymmU8,
1292 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001293 };
1294
1295 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1296 "Reference multiplication: input 0 is not a supported type.");
1297
1298 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1299 "Reference multiplication: input 1 is not a supported type.");
1300
1301 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1302 "Reference multiplication: output is not a supported type.");
1303
1304 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1305 "Reference multiplication: input 0 and Input 1 types are mismatched");
1306
1307 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1308 "Reference multiplication: input and output types are mismatched");
1309
1310 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1311 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1312
1313 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001314}
1315
1316bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1317 const TensorInfo& output,
1318 const NormalizationDescriptor& descriptor,
1319 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001320{
Nina Drozd661dfa72018-10-02 11:14:17 +01001321 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001322
1323 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001324 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001325 {
1326 DataType::Float16,
1327 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001328 DataType::QAsymmU8,
1329 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001330 };
1331
1332 bool supported = true;
1333
1334 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1335 "Reference normalization: input type not supported.");
1336
1337 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1338 "Reference normalization: output type not supported.");
1339
1340 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1341 "Reference normalization: input and output shapes have different "
1342 "num total elements.");
1343
1344 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001345}
1346
Derek Lamberti901ea112019-12-10 22:07:09 +00001347bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1348 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001349{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001350 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001351}
1352
1353bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1354 const TensorInfo& output,
1355 const PadDescriptor& descriptor,
1356 Optional<std::string&> reasonIfUnsupported) const
1357{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001358 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001359 bool supported = true;
1360
1361 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001362 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001363 {
1364 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001365 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001366 DataType::QAsymmU8,
1367 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001368 };
1369
1370 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1371 "Reference pad: input is not a supported type.");
1372
1373 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1374 "Reference pad: output is not a supported type.");
1375
1376 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1377 "Reference pad: input and output types are mismatched.");
1378
1379 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001380}
1381
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001382bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1383 const TensorInfo& output,
1384 const PermuteDescriptor& descriptor,
1385 Optional<std::string&> reasonIfUnsupported) const
1386{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001387 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001388 bool supported = true;
1389
1390 // Define supported output and inputs types.
1391 std::array<DataType,3> supportedTypes =
1392 {
1393 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001394 DataType::QAsymmU8,
1395 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001396 };
1397
1398 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1399 "Reference permute: input is not a supported type.");
1400
1401 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1402 "Reference permute: output is not a supported type.");
1403
1404 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1405 "Reference permute: input and output types are mismatched.");
1406
1407 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001408}
1409
1410bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1411 const TensorInfo& output,
1412 const Pooling2dDescriptor& descriptor,
1413 Optional<std::string&> reasonIfUnsupported) const
1414{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001415 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001416 bool supported = true;
1417
1418 // Define supported output and inputs types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001419 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001420 {
1421 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001422 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001423 DataType::QSymmS8,
1424 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001425 DataType::QAsymmU8,
1426 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001427 };
1428
1429 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1430 "Reference poolind2d: input is not a supported type.");
1431
1432 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1433 "Reference poolind2d: output is not a supported type.");
1434
1435 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1436 "Reference poolind2d: input and output types are mismatched.");
1437
1438 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001439}
1440
Derek Lamberti5f400d62019-03-25 15:41:58 +00001441bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1442 const TensorInfo& output,
1443 Optional<std::string&> reasonIfUnsupported) const
1444{
1445 bool supported = true;
1446
Finn Williamsfd271062019-12-04 14:27:27 +00001447 // Define supported input types.
Ryan OShea9add1202020-02-07 10:06:33 +00001448 std::array<DataType,6> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00001449 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001450 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001451 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001452 DataType::QAsymmU8,
1453 DataType::QSymmS8,
1454 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001455 };
1456
1457 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1458 "Reference quantize: input type not supported.");
1459
1460 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001461 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001462 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001463 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001464 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001465 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001466 };
1467 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1468 "Reference quantize: output type not supported.");
1469
1470 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1471 "Reference quantize: input and output shapes have different num total elements.");
1472
1473 return supported;
1474}
1475
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001476bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001477 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001478 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001479 Optional<std::string&> reasonIfUnsupported) const
1480{
Kevin Maya023c402019-12-12 17:28:05 +00001481 ignore_unused(output);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001482 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001483 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001484 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001485 {
1486 DataType::Float32,
1487 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001488 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001489 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001490 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001491 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001492 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001493 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001494
Nina Drozd2f2778f2019-05-27 10:37:05 +01001495 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1496 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001497}
1498
1499bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001500 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001501 Optional<std::string&> reasonIfUnsupported) const
1502{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001503 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001504 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001505 {
1506 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001507 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001508 DataType::QAsymmU8,
1509 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001510 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001511
1512 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1513 "Reference ResizeBilinear: input type not supported");
1514
1515 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1516 "Reference ResizeBilinear: output type not supported");
1517
1518 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1519 "Reference ResizeBilinear: input and output types not matching");
1520
1521 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001522}
1523
Teresa Charlin970f43b2019-07-01 13:51:07 +01001524bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1525 const TensorInfo& output,
1526 const ResizeDescriptor& descriptor,
1527 Optional<std::string&> reasonIfUnsupported) const
1528{
Derek Lamberti901ea112019-12-10 22:07:09 +00001529 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001530 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +00001531 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001532 {
1533 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001534 DataType::Float16,
Keith Davis5204aa82020-01-27 15:24:59 +00001535 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001536 DataType::QAsymmU8,
1537 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001538 };
1539
1540 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1541 "Reference Resize: input type not supported");
1542
1543 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1544 "Reference Resize: output type not supported");
1545
1546 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1547 "Reference Resize: input and output types not matching");
1548
1549 return supported;
1550}
1551
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001552bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1553 const TensorInfo& output,
1554 Optional<std::string&> reasonIfUnsupported) const
1555{
josh minor4a3c6102020-01-06 16:40:46 -06001556 return IsElementwiseUnarySupported(input,
1557 output,
1558 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1559 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001560}
1561
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001562bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1563 const TensorInfo& output,
1564 const SliceDescriptor& descriptor,
1565 Optional<std::string&> reasonIfUnsupported) const
1566{
Derek Lamberti901ea112019-12-10 22:07:09 +00001567 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001568 bool supported = true;
1569
1570 std::array<DataType, 3> supportedTypes =
1571 {
1572 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001573 DataType::QAsymmU8,
1574 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001575 };
1576
1577 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1578 "Reference Slice: input type not supported");
1579
1580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1581 "Reference Slice: output type not supported");
1582
1583 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1584 "Reference Slice: input and output types are mismatched");
1585
1586 return supported;
1587}
1588
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001589bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1590 const TensorInfo& output,
1591 const SoftmaxDescriptor& descriptor,
1592 Optional<std::string&> reasonIfUnsupported) const
1593{
Derek Lamberti901ea112019-12-10 22:07:09 +00001594 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001595 bool supported = true;
Keith Davis0c2eeac2020-02-11 16:51:50 +00001596 std::array<DataType,6> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001597 {
1598 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001599 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001600 DataType::QSymmS8,
1601 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001602 DataType::QAsymmU8,
1603 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001604 };
1605
1606 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001607 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001608
1609 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001610 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001611
1612 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001613 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001614
1615 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001616}
1617
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001618bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1619 const TensorInfo& output,
1620 const SpaceToBatchNdDescriptor& descriptor,
1621 Optional<std::string&> reasonIfUnsupported) const
1622{
Derek Lamberti901ea112019-12-10 22:07:09 +00001623 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001624 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001625 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001626 {
1627 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001628 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001629 DataType::QAsymmU8,
1630 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001631 };
1632
1633 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1634 "Reference SpaceToBatchNd: input type not supported");
1635
1636 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1637 "Reference SpaceToBatchNd: output type not supported");
1638
1639 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1640 "Reference SpaceToBatchNd: input and output types are mismatched");
1641
1642 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001643}
1644
Keith Davisa57eccb2019-06-14 17:33:22 +01001645bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001646 const TensorInfo& output,
1647 const SpaceToDepthDescriptor& descriptor,
1648 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001649{
1650
1651 ignore_unused(descriptor);
1652 bool supported = true;
1653
Matthew Jackson9bff1442019-09-12 09:08:23 +01001654 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001655 {
1656 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001657 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001658 DataType::QAsymmU8,
1659 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001660 };
1661
1662 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1663 "Reference SpaceToDepth: input type not supported");
1664
1665 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1666 "Reference SpaceToDepth: output type not supported");
1667
1668 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1669 "Reference SpaceToDepth: input and output types are mismatched");
1670
1671 return supported;
1672}
1673
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001674bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1675 const ViewsDescriptor& descriptor,
1676 Optional<std::string&> reasonIfUnsupported) const
1677{
1678 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001679 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001680 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001681 {
1682 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001683 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001684 DataType::QAsymmU8,
1685 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001686 };
1687
1688 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1689 "Reference splitter: input type not supported");
1690
1691 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001692}
1693
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001694bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1695 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1696 const ViewsDescriptor& descriptor,
1697 Optional<std::string&> reasonIfUnsupported) const
1698{
1699 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001700 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001701 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001702 {
1703 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001704 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001705 DataType::QAsymmU8,
1706 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001707 };
1708
1709 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1710 "Reference splitter: output type not supported");
1711 for (const TensorInfo output : outputs)
1712 {
1713 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1714 "Reference splitter: input type not supported");
1715
1716 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1717 "Reference splitter: input and output types mismatched.");
1718 }
1719
1720 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001721}
1722
Matthew Jackson81e601c2019-07-11 12:07:09 +01001723bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1724 const TensorInfo& output,
1725 const StackDescriptor& descriptor,
1726 Optional<std::string&> reasonIfUnsupported) const
1727{
1728 ignore_unused(descriptor);
1729
1730 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001731 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001732 {
1733 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001734 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001735 DataType::QAsymmU8,
1736 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001737 };
1738
1739 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1740 "Reference stack: output type not supported");
1741 for (const TensorInfo* input : inputs)
1742 {
1743 BOOST_ASSERT(input != nullptr);
1744 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1745 "Reference stack: input type not supported");
1746
1747 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1748 "Reference stack: input and output types mismatched.");
1749 }
1750
1751 return supported;
1752}
1753
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001754bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1755 const TensorInfo& output,
1756 const StridedSliceDescriptor& descriptor,
1757 Optional<std::string&> reasonIfUnsupported) const
1758{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001759 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001760 bool supported = true;
1761
1762 std::array<DataType,3> supportedTypes =
1763 {
1764 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001765 DataType::QAsymmU8,
1766 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001767 };
1768
1769 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1770 "Reference StridedSlice: input type not supported");
1771
1772 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1773 "Reference StridedSlice: output type not supported");
1774
1775 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1776 "Reference StridedSlice: input and output types are mismatched");
1777
1778 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001779}
1780
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001781bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1782 const TensorInfo& input1,
1783 const TensorInfo& output,
1784 Optional<std::string&> reasonIfUnsupported) const
1785{
Sadik Armagan2999a022019-04-09 14:20:12 +01001786 bool supported = true;
1787
Matthew Jackson9bff1442019-09-12 09:08:23 +01001788 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001789 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001790 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001791 DataType::QAsymmU8,
1792 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001793 };
1794
1795 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1796 "Reference subtraction: input 0 is not a supported type.");
1797
1798 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1799 "Reference subtraction: input 1 is not a supported type.");
1800
1801 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1802 "Reference subtraction: output is not a supported type.");
1803
1804 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1805 "Reference subtraction: input 0 and Input 1 types are mismatched");
1806
1807 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1808 "Reference subtraction: input and output types are mismatched");
1809
1810 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1811 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1812
1813 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001814}
1815
Matteo Martincighab9e5252019-06-13 17:27:46 +01001816bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1817 const TensorInfo& alpha,
1818 const TensorInfo& output,
1819 Optional<std::string&> reasonIfUnsupported) const
1820{
1821 bool supported = true;
1822
Matthew Jackson9bff1442019-09-12 09:08:23 +01001823 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001824 {
1825 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001826 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001827 DataType::QAsymmU8,
1828 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001829 };
1830
1831 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1832 "PReLU: input is not a supported type.");
1833
1834 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1835 "PReLU: alpha is not a supported type.");
1836
1837 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1838 "PReLU: output is not a supported type.");
1839
1840 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1841 "PReLU: input, alpha and output types are mismatched");
1842
1843 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1844 "PReLU: shapes are not suitable for implicit broadcast");
1845
1846 return supported;
1847}
1848
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001849bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1850 const TensorInfo& output,
1851 const TransposeConvolution2dDescriptor& descriptor,
1852 const TensorInfo& weights,
1853 const Optional<TensorInfo>& biases,
1854 Optional<std::string&> reasonIfUnsupported) const
1855{
Derek Lamberti901ea112019-12-10 22:07:09 +00001856 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001857 bool supported = true;
1858
Matthew Jackson252df3a2019-09-11 09:19:18 +01001859 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001860 {
1861 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001862 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001863 DataType::QAsymmU8,
1864 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001865 };
1866
1867 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1868 "Reference TransposeConvolution2d: input is not a supported type.");
1869
1870 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1871 "Reference TransposeConvolution2d: output is not a supported type.");
1872
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001873 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1874 "Reference TransposeConvolution2d: input and output types mismatched.");
1875
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001876
1877 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001878 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001879 {
Derek Lambertid466a542020-01-22 15:37:29 +00001880 ARMNN_NO_DEPRECATE_WARN_BEGIN
1881 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001882 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001883 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001884 DataType::QSymmS8,
1885 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001886 };
Derek Lambertid466a542020-01-22 15:37:29 +00001887 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001888
1889 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1890 "Reference TransposeConvolution2d: weights type not supported for "
1891 "quantized input.");
1892 }
1893 else
1894 {
1895 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1896 "Reference TransposeConvolution2d: weights is not a supported type.");
1897
1898 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1899 "Reference TransposeConvolution2d: input and weights types mismatched.");
1900 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001901
1902 if (biases.has_value())
1903 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001904 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001905 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001906 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001907 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001908 DataType::Signed32
1909 };
1910 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1911 "Reference TransposeConvolution2d: biases is not a supported type.");
1912 }
1913
1914 return supported;
1915}
1916
arovir011c7c81b2018-10-08 11:34:28 +01001917} // namespace armnn