//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <LayerSupportCommon.hpp>
#include <backendsCommon/LayerSupportRules.hpp>

#include <boost/cast.hpp>

#include <vector>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) +
                           " dimensions but got " + std::to_string(actual) +
                           " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return IsElementwiseUnarySupported(input,
                                       output,
                                       ElementwiseUnaryDescriptor(UnaryOperation::Abs),
                                       reasonIfUnsupported);
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

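    // Local rule: accepts only the activation functions that the reference backend implements.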
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
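
// Illustrative sketch (not part of the backend): a caller can query support before assigning a
// layer to the reference backend. The TensorShape/TensorInfo constructors used below are
// assumptions about the public ArmNN API rather than anything defined in this file.
//
//     RefLayerSupport layerSupport;
//     TensorInfo info(TensorShape({1, 16}), DataType::Float32);
//     std::string reason;
//     bool ok = layerSupport.IsActivationSupported(info, info,
//                                                  ActivationDescriptor(),
//                                                  Optional<std::string&>(reason));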

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 5> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Signed32,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
                                  "Reference for ConvertBf16ToFp32 layer: input type not supported");

    supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
                                  "Reference for ConvertBf16ToFp32 layer: output type not supported");

    return supported;
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
                                  "Reference for ConvertFp32ToBf16 layer: input type not supported");

    supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
                                  "Reference for ConvertFp32ToBf16 layer: output type not supported");

    return supported;
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Convolution2d: input and output types mismatched.");

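    // For quantized inputs the weights must use one of the 8-bit quantized types below; the
    // per-axis enum is kept only for backwards compatibility. Float paths require the weights
    // to share the input type.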
    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QAsymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference Convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: biases is not a supported type.");
    }
    IgnoreUnused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 8> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference for Debug layer: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS8,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    ARMNN_NO_DEPRECATE_WARN_BEGIN
    std::array<DataType, 3> supportedWeightTypes =
    {
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QuantizedSymm8PerAxis // deprecated
    };
    ARMNN_NO_DEPRECATE_WARN_END

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    IgnoreUnused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedInputTypes = {
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input type not supported.");

    supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
                                  "Reference for Dequantize layer: per-axis quantized input not supported.");

    std::array<DataType,3> supportedOutputTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input/output shapes have different num total "
                                  "elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,4> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,5> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  const ElementwiseUnaryDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output shapes "
                                  "have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(output);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    ARMNN_NO_DEPRECATE_WARN_BEGIN
    std::array<DataType, 3> supportedWeightTypes =
    {
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QuantizedSymm8PerAxis // deprecated
    };
    ARMNN_NO_DEPRECATE_WARN_END

    if (IsQuantized8BitType(input.GetDataType()))
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: input and weights types mismatched.");
    }

    if (descriptor.m_BiasEnabled)
    {
        // Define supported types for bias.
        std::array<DataType, 4> supportedBiasTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    // Define supported types
    std::array<DataType, 5> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);

    bool supported = true;

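    // The reference LSTM implementation only handles BFloat16, Float32 and QSymmS16 data.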
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001096 std::array<DataType,3> supportedTypes = {
1097 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001098 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001099 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001100 };
1101
Jan Eilersd01a83c2019-07-03 18:20:40 +01001102 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001103 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1104 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001105 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1106 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001107 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1108 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001109 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1110 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001111 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1112 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001113 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1114 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001115 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1116 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001117 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001118 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001119 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001120 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001121 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001122 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001123 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001124 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001125 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001126 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001127 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001128 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001129 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001130 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001131 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001132 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001133 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001134 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001135 "Reference Lstm: input and OutputGateBias types are mismatched");
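// When CIFG (coupled input and forget gate) is disabled the layer has a separate input gate,
// so its weights and bias must match the input type too; the optional peephole, projection
// and layer-normalisation parameters are checked in the same way below.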
1136 if (!descriptor.m_CifgEnabled)
1137 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001138 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001139 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001140 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001141 reasonIfUnsupported,
1142 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001143 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001144 "Reference Lstm: input and InputGateBias types are mismatched");
1145 if (descriptor.m_PeepholeEnabled)
1146 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001147 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001148 reasonIfUnsupported,
1149 "Reference Lstm: input and CellToInputWeights types are mismatched");
1150 }
1151 }
1152 if (descriptor.m_PeepholeEnabled)
1153 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001154 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001155 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001156 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001157 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1158 }
1159 if (descriptor.m_ProjectionEnabled)
1160 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001161 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001162 "Reference Lstm: input and ProjectionWeights types are mismatched");
1163 if (paramsInfo.m_ProjectionBias != nullptr)
1164 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001165 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001166 "Reference Lstm: input and ProjectionBias types are mismatched");
1167 }
1168 }
1169 if (descriptor.m_LayerNormEnabled)
1170 {
1171 if (!descriptor.m_CifgEnabled)
1172 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001173 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001174 reasonIfUnsupported,
1175 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1176 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001177 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001178 reasonIfUnsupported,
1179 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001180 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001181 reasonIfUnsupported,
1182 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001183 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001184 reasonIfUnsupported,
1185 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1186 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001187
1188 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001189}
1190
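// The remaining Is<Layer>Supported methods follow the same pattern: each accumulates
// CheckSupportRule results (TypeAnyOf, TypesAreEqual, ShapesAreBroadcastCompatible, ...) into
// 'supported' and records a human-readable reason when a rule fails. A minimal usage sketch
// (the TensorInfo variables are hypothetical; only the call shape is taken from this file):
//
//     std::string reason;
//     RefLayerSupport layerSupport;
//     bool ok = layerSupport.IsMaximumSupported(input0Info, input1Info, outputInfo,
//                                               Optional<std::string&>(reason));
//     // If 'ok' is false, 'reason' describes which rule rejected the configuration.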
saoste012df12b32018-11-28 16:57:20 +00001191bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1192 const TensorInfo& input1,
1193 const TensorInfo& output,
1194 Optional<std::string&> reasonIfUnsupported) const
1195{
Sadik Armagan2999a022019-04-09 14:20:12 +01001196 bool supported = true;
1197
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001198 std::array<DataType,6> supportedTypes = {
1199 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001200 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001201 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001202 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001203 DataType::QAsymmU8,
1204 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001205 };
1206
1207 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1208 "Reference maximum: input 0 is not a supported type.");
1209
1210 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1211 "Reference maximum: input 1 is not a supported type.");
1212
1213 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1214 "Reference maximum: output is not a supported type.");
1215
1216 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1217 "Reference maximum: input 0 and input 1 types are mismatched");
1218
1219 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1220 "Reference maximum: input and output types are mismatched");
1221
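// Implicit broadcast: each input dimension must either equal the corresponding output
// dimension or be 1 (missing leading dimensions are treated as 1); ShapesAreBroadcastCompatible
// checks this against the supplied output shape.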
1222 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1223 "Reference maximum: shapes are not suitable for implicit broadcast.");
1224
1225 return supported;
saoste012df12b32018-11-28 16:57:20 +00001226}
1227
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001228bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1229 const TensorInfo& output,
1230 const MeanDescriptor& descriptor,
1231 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001232{
James Conroy4d1ff582019-06-10 17:06:39 +01001233 bool supported = true;
1234 std::string meanLayerStr = "Mean";
1235 std::string outputTensorStr = "output";
1236
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001237 std::array<DataType,5> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001238 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001239 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001240 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001241 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001242 DataType::QAsymmU8,
1243 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001244 };
1245
1246 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1247 "Reference Mean: input type not supported.");
1248
1249 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1250 "Reference Mean: input and output types are mismatched");
1251
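// Expected output rank: with KeepDims the output keeps the input's rank; with an empty axis
// list the tensor is reduced to a single dimension; otherwise one dimension is removed per
// reduced axis, falling back to rank 1 when every dimension is reduced away.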
1252 if (descriptor.m_KeepDims)
1253 {
1254 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1255 reasonIfUnsupported,
1256 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1257 output.GetNumDimensions(),
1258 meanLayerStr, outputTensorStr).data());
1259 }
1260 else if (descriptor.m_Axis.empty())
1261 {
1262 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1263 reasonIfUnsupported,
1264 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1265 meanLayerStr, outputTensorStr).data());
1266 }
1267 else
1268 {
1269 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1270
1271 if (outputDim > 0)
1272 {
1273 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1274 reasonIfUnsupported,
1275 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1276 meanLayerStr, outputTensorStr).data());
1277 }
1278 else
1279 {
1280 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1281 reasonIfUnsupported,
1282 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1283 meanLayerStr, outputTensorStr).data());
1284 }
1285 }
1286
1287 return supported;
narpra0132b90462018-09-13 11:07:48 +01001288}
1289
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001290bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001291 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001292 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001293 Optional<std::string&> reasonIfUnsupported) const
1294{
Jim Flynne242f2d2019-05-22 14:24:13 +01001295 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001296}
1297
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001298bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1299 const TensorInfo &output,
1300 Optional<std::string &> reasonIfUnsupported) const
1301{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001302 bool supported = true;
1303
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001304 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001305 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001306 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001307 DataType::Float32,
1308 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001309 DataType::QAsymmU8,
1310 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001311 DataType::Boolean
1312 };
1313
1314 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1315 "Reference MemCopy: input type not supported");
1316
1317 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1318 "Reference MemCopy: output type not supported");
1319
1320 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1321 "Reference MemCopy: input and output types are mismatched");
1322
1323 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001324}
1325
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001326bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1327 const TensorInfo& input1,
1328 const TensorInfo& output,
1329 Optional<std::string&> reasonIfUnsupported) const
1330{
Sadik Armagan2999a022019-04-09 14:20:12 +01001331 bool supported = true;
1332
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001333 std::array<DataType,5> supportedTypes = {
1334 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001335 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001336 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001337 DataType::QAsymmU8,
1338 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001339 };
1340
1341 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1342 "Reference minimum: input 0 is not a supported type.");
1343
1344 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1345 "Reference minimum: input 1 is not a supported type.");
1346
1347 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1348 "Reference minimum: output is not a supported type.");
1349
1350 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1351 "Reference minimum: input 0 and input 1 types are mismatched");
1352
1353 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1354 "Reference minimum: input and output types are mismatched");
1355
1356 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1357 "Reference minimum: shapes are not suitable for implicit broadcast.");
1358
1359 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001360}
1361
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001362bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1363 const TensorInfo& input1,
1364 const TensorInfo& output,
1365 Optional<std::string&> reasonIfUnsupported) const
1366{
Sadik Armagan2999a022019-04-09 14:20:12 +01001367 bool supported = true;
1368
Keith Davis67e6c542020-02-19 10:08:33 +00001369 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001370 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001371 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001372 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001373 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001374 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001375 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001376 };
1377
1378 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1379 "Reference multiplication: input 0 is not a supported type.");
1380
1381 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1382 "Reference multiplication: input 1 is not a supported type.");
1383
1384 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1385 "Reference multiplication: output is not a supported type.");
1386
1387 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1388 "Reference multiplication: input 0 and input 1 types are mismatched");
1389
1390 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1391 "Reference multiplication: input and output types are mismatched");
1392
1393 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1394 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1395
1396 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001397}
1398
1399bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1400 const TensorInfo& output,
1401 const NormalizationDescriptor& descriptor,
1402 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001403{
Jan Eilers8eb25602020-03-09 12:13:48 +00001404 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001405
1406 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001407 std::array<DataType, 5> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001408 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001409 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001410 DataType::Float16,
1411 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001412 DataType::QAsymmU8,
1413 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001414 };
1415
1416 bool supported = true;
1417
1418 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1419 "Reference normalization: input type not supported.");
1420
1421 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1422 "Reference normalization: output type not supported.");
1423
1424 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1425 "Reference normalization: input and output shapes have different "
1426 "num total elements.");
1427
1428 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001429}
1430
Derek Lamberti901ea112019-12-10 22:07:09 +00001431bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1432 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001433{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001434 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001435}
1436
1437bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1438 const TensorInfo& output,
1439 const PadDescriptor& descriptor,
1440 Optional<std::string&> reasonIfUnsupported) const
1441{
Jan Eilers8eb25602020-03-09 12:13:48 +00001442 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001443 bool supported = true;
1444
1445 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001446 std::array<DataType,5> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001447 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001448 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001449 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001450 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001451 DataType::QAsymmU8,
1452 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001453 };
1454
1455 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1456 "Reference pad: input is not a supported type.");
1457
1458 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1459 "Reference pad: output is not a supported type.");
1460
1461 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1462 "Reference pad: input and output types are mismatched.");
1463
1464 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001465}
1466
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001467bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1468 const TensorInfo& output,
1469 const PermuteDescriptor& descriptor,
1470 Optional<std::string&> reasonIfUnsupported) const
1471{
Jan Eilers8eb25602020-03-09 12:13:48 +00001472 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001473 bool supported = true;
1474
1475 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001476 std::array<DataType, 5> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001477 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001478 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001479 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001480 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001481 DataType::QAsymmU8,
1482 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001483 };
1484
1485 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1486 "Reference permute: input is not a supported type.");
1487
1488 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1489 "Reference permute: output is not a supported type.");
1490
1491 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1492 "Reference permute: input and output types are mismatched.");
1493
1494 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001495}
1496
1497bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1498 const TensorInfo& output,
1499 const Pooling2dDescriptor& descriptor,
1500 Optional<std::string&> reasonIfUnsupported) const
1501{
Jan Eilers8eb25602020-03-09 12:13:48 +00001502 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001503 bool supported = true;
1504
1505 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001506 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001507 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001508 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001509 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001510 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001511 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001512 DataType::QAsymmU8,
1513 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001514 };
1515
1516 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1517 "Reference pooling2d: input is not a supported type.");
1518
1519 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1520 "Reference pooling2d: output is not a supported type.");
1521
1522 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1523 "Reference pooling2d: input and output types are mismatched.");
1524
1525 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001526}
1527
Derek Lamberti5f400d62019-03-25 15:41:58 +00001528bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1529 const TensorInfo& output,
1530 Optional<std::string&> reasonIfUnsupported) const
1531{
1532 bool supported = true;
1533
Finn Williamsfd271062019-12-04 14:27:27 +00001534 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 std::array<DataType,7> supportedInputTypes = {
1536 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001537 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001538 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001539 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001540 DataType::QAsymmU8,
1541 DataType::QSymmS8,
1542 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001543 };
1544
1545 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1546 "Reference quantize: input type not supported.");
1547
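// Note that the input list above already contains the quantized types, so these checks accept
// re-quantising an already-quantized tensor as well as quantising from a float type.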
1548 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001549 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001550 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001551 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001552 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001553 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001554 };
1555 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1556 "Reference quantize: output type not supported.");
1557
1558 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1559 "Reference quantize: input and output shapes have different num total elements.");
1560
1561 return supported;
1562}
1563
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001564bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001565 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001566 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001567 Optional<std::string&> reasonIfUnsupported) const
1568{
Jan Eilers8eb25602020-03-09 12:13:48 +00001569 IgnoreUnused(output);
1570 IgnoreUnused(descriptor);
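// Only the input type is validated here: a reshape changes neither the element type nor the
// total number of elements, so the output info (ignored above) is expected to share them.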
Nina Drozd2f2778f2019-05-27 10:37:05 +01001571 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001572 std::array<DataType,7> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001573 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001574 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001575 DataType::Float32,
1576 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001577 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001578 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001579 DataType::QAsymmU8,
1580 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001581 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001582
Nina Drozd2f2778f2019-05-27 10:37:05 +01001583 return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1584 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001585}
1586
1587bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001588 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001589 Optional<std::string&> reasonIfUnsupported) const
1590{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001591 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001592 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001593 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001594 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001595 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001596 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001597 DataType::QAsymmU8,
1598 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001600
1601 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1602 "Reference ResizeBilinear: input type not supported");
1603
1604 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1605 "Reference ResizeBilinear: output type not supported");
1606
1607 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1608 "Reference ResizeBilinear: input and output types not matching");
1609
1610 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001611}
1612
Teresa Charlin970f43b2019-07-01 13:51:07 +01001613bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1614 const TensorInfo& output,
1615 const ResizeDescriptor& descriptor,
1616 Optional<std::string&> reasonIfUnsupported) const
1617{
Jan Eilers8eb25602020-03-09 12:13:48 +00001618 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001619 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001620 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001621 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001622 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001623 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001624 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001625 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001626 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001627 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001628 };
1629
1630 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1631 "Reference Resize: input type not supported");
1632
1633 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1634 "Reference Resize: output type not supported");
1635
1636 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1637 "Reference Resize: input and output types not matching");
1638
1639 return supported;
1640}
1641
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001642bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1643 const TensorInfo& output,
1644 Optional<std::string&> reasonIfUnsupported) const
1645{
josh minor4a3c6102020-01-06 16:40:46 -06001646 return IsElementwiseUnarySupported(input,
1647 output,
1648 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1649 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001650}
1651
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001652bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1653 const TensorInfo& output,
1654 const SliceDescriptor& descriptor,
1655 Optional<std::string&> reasonIfUnsupported) const
1656{
Jan Eilers8eb25602020-03-09 12:13:48 +00001657 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001658 bool supported = true;
1659
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001660 std::array<DataType, 4> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001661 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001662 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001663 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001664 DataType::QAsymmU8,
1665 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001666 };
1667
1668 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1669 "Reference Slice: input type not supported");
1670
1671 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1672 "Reference Slice: output type not supported");
1673
1674 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1675 "Reference Slice: input and output types are mismatched");
1676
1677 return supported;
1678}
1679
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001680bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1681 const TensorInfo& output,
1682 const SoftmaxDescriptor& descriptor,
1683 Optional<std::string&> reasonIfUnsupported) const
1684{
Jan Eilers8eb25602020-03-09 12:13:48 +00001685 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001686 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001687 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001688 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001689 DataType::BFloat16,
1690 DataType::Float32,
1691 DataType::Float16,
1692 DataType::QSymmS8,
1693 DataType::QAsymmS8,
1694 DataType::QAsymmU8,
1695 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001696 };
1697
1698 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001699 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001700
1701 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001702 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001703
1704 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001705 "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001706
1707 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001708}
1709
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001710bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1711 const TensorInfo& output,
1712 const SpaceToBatchNdDescriptor& descriptor,
1713 Optional<std::string&> reasonIfUnsupported) const
1714{
Jan Eilers8eb25602020-03-09 12:13:48 +00001715 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001716 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001717 std::array<DataType,5> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001718 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001719 DataType::BFloat16,
1720 DataType::Float32,
1721 DataType::Float16,
1722 DataType::QAsymmU8,
1723 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001724 };
1725
1726 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1727 "Reference SpaceToBatchNd: input type not supported");
1728
1729 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1730 "Reference SpaceToBatchNd: output type not supported");
1731
1732 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1733 "Reference SpaceToBatchNd: input and output types are mismatched");
1734
1735 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001736}
1737
Keith Davisa57eccb2019-06-14 17:33:22 +01001738bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001739 const TensorInfo& output,
1740 const SpaceToDepthDescriptor& descriptor,
1741 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001742{
1743
Jan Eilers8eb25602020-03-09 12:13:48 +00001744 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001745 bool supported = true;
1746
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001747 std::array<DataType,5> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001748 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001749 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001750 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001751 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001752 DataType::QAsymmU8,
1753 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001754 };
1755
1756 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1757 "Reference SpaceToDepth: input type not supported");
1758
1759 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1760 "Reference SpaceToDepth: output type not supported");
1761
1762 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1763 "Reference SpaceToDepth: input and output types are mismatched");
1764
1765 return supported;
1766}
1767
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001768bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1769 const ViewsDescriptor& descriptor,
1770 Optional<std::string&> reasonIfUnsupported) const
1771{
Jan Eilers8eb25602020-03-09 12:13:48 +00001772 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001773 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001774 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001775 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001776 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001777 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001778 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001779 DataType::QAsymmU8,
1780 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001781 };
1782
1783 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1784 "Reference splitter: input type not supported");
1785
1786 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001787}
1788
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001789bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1790 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1791 const ViewsDescriptor& descriptor,
1792 Optional<std::string&> reasonIfUnsupported) const
1793{
Jan Eilers8eb25602020-03-09 12:13:48 +00001794 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001795 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001796 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001797 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001798 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001799 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001800 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001801 DataType::QAsymmU8,
1802 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001803 };
1804
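// The input and every output view produced by the split must use one of the supported types,
// and each output must match the input type exactly.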
1805 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1806 "Reference splitter: input type not supported");
1807 for (const TensorInfo& output : outputs)
1808 {
1809 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1810 "Reference splitter: output type not supported");
1811
1812 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1813 "Reference splitter: input and output types mismatched.");
1814 }
1815
1816 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001817}
1818
Matthew Jackson81e601c2019-07-11 12:07:09 +01001819bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1820 const TensorInfo& output,
1821 const StackDescriptor& descriptor,
1822 Optional<std::string&> reasonIfUnsupported) const
1823{
Jan Eilers8eb25602020-03-09 12:13:48 +00001824 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001825
1826 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001827 std::array<DataType,5> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001828 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001829 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001830 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001831 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001832 DataType::QAsymmU8,
1833 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001834 };
1835
1836 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1837 "Reference stack: output type not supported");
1838 for (const TensorInfo* input : inputs)
1839 {
1840 BOOST_ASSERT(input != nullptr);
1841 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1842 "Reference stack: input type not supported");
1843
1844 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1845 "Reference stack: input and output types mismatched.");
1846 }
1847
1848 return supported;
1849}
1850
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001851bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1852 const TensorInfo& output,
1853 const StridedSliceDescriptor& descriptor,
1854 Optional<std::string&> reasonIfUnsupported) const
1855{
Jan Eilers8eb25602020-03-09 12:13:48 +00001856 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001857 bool supported = true;
1858
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001859 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001860 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001861 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001862 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001863 DataType::QAsymmU8,
1864 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001865 };
1866
1867 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1868 "Reference StridedSlice: input type not supported");
1869
1870 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1871 "Reference StridedSlice: output type not supported");
1872
1873 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1874 "Reference StridedSlice: input and output types are mismatched");
1875
1876 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001877}
1878
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001879bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1880 const TensorInfo& input1,
1881 const TensorInfo& output,
1882 Optional<std::string&> reasonIfUnsupported) const
1883{
Sadik Armagan2999a022019-04-09 14:20:12 +01001884 bool supported = true;
1885
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001886 std::array<DataType,5> supportedTypes = {
1887 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001888 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001889 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001890 DataType::QAsymmU8,
1891 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001892 };
1893
1894 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1895 "Reference subtraction: input 0 is not a supported type.");
1896
1897 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1898 "Reference subtraction: input 1 is not a supported type.");
1899
1900 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1901 "Reference subtraction: output is not a supported type.");
1902
1903 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1904 "Reference subtraction: input 0 and input 1 types are mismatched");
1905
1906 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1907 "Reference subtraction: input and output types are mismatched");
1908
1909 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1910 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1911
1912 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001913}
1914
Matteo Martincighab9e5252019-06-13 17:27:46 +01001915bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1916 const TensorInfo& alpha,
1917 const TensorInfo& output,
1918 Optional<std::string&> reasonIfUnsupported) const
1919{
1920 bool supported = true;
1921
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001922 std::array<DataType, 5> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001923 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001924 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001925 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001926 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001927 DataType::QAsymmU8,
1928 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001929 };
1930
1931 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1932 "Reference PReLU: input is not a supported type.");
1933
1934 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1935 "Reference PReLU: alpha is not a supported type.");
1936
1937 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1938 "Reference PReLU: output is not a supported type.");
1939
1940 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1941 "Reference PReLU: input, alpha and output types are mismatched");
1942
1943 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1944 "Reference PReLU: shapes are not suitable for implicit broadcast");
1945
1946 return supported;
1947}
1948
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001949bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1950 const TensorInfo& output,
1951 const TransposeConvolution2dDescriptor& descriptor,
1952 const TensorInfo& weights,
1953 const Optional<TensorInfo>& biases,
1954 Optional<std::string&> reasonIfUnsupported) const
1955{
Jan Eilers8eb25602020-03-09 12:13:48 +00001956 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001957 bool supported = true;
1958
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001959 std::array<DataType,5> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001960 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001961 DataType::BFloat16,
1962 DataType::Float32,
1963 DataType::Float16,
1964 DataType::QAsymmU8,
1965 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001966 };
1967
1968 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1969 "Reference TransposeConvolution2d: input is not a supported type.");
1970
1971 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1972 "Reference TransposeConvolution2d: output is not a supported type.");
1973
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001974 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1975 "Reference TransposeConvolution2d: input and output types mismatched.");
1976
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001977
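// For a quantized input the weights may use a different, symmetric (optionally per-axis)
// quantization scheme, so they are validated against a separate list; QuantizedSymm8PerAxis
// is retained only for backwards compatibility and is deprecated in favour of per-axis QSymmS8.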
1978 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001979 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001980 {
Derek Lambertid466a542020-01-22 15:37:29 +00001981 ARMNN_NO_DEPRECATE_WARN_BEGIN
1982 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001983 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001984 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001985 DataType::QSymmS8,
1986 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001987 };
Derek Lambertid466a542020-01-22 15:37:29 +00001988 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001989
1990 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1991 "Reference TransposeConvolution2d: weights type not supported for "
1992 "quantized input.");
1993 }
1994 else
1995 {
1996 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1997 "Reference TransposeConvolution2d: weights is not a supported type.");
1998
1999 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2000 "Reference TransposeConvolution2d: input and weights types mismatched.");
2001 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002002
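// Biases are optional; when supplied they must be one of the float types or Signed32, the
// type normally used for bias data with quantized inputs.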
2003 if (biases.has_value())
2004 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002005 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002006 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002007 DataType::BFloat16,
2008 DataType::Float32,
2009 DataType::Float16,
2010 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002011 };
2012 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2013 "Reference TransposeConvolution2d: biases is not a supported type.");
2014 }
2015
2016 return supported;
2017}
2018
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002019bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2020 const TensorInfo& output,
2021 const TransposeDescriptor& descriptor,
2022 Optional<std::string&> reasonIfUnsupported) const
2023{
Jan Eilers8eb25602020-03-09 12:13:48 +00002024 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002025 bool supported = true;
2026
2027 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002028 std::array<DataType, 5> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002029 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002030 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002031 DataType::Float32,
2032 DataType::Float16,
2033 DataType::QAsymmU8,
2034 DataType::QSymmS16
2035 };
2036
2037 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2038 "Reference transpose: input is not a supported type.");
2039
2040 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2041 "Reference transpose: output is not a supported type.");
2042
2043 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2044 "Reference transpose: input and output types are mismatched.");
2045
2046 return supported;
2047}
2048
arovir011c7c81b2018-10-08 11:34:28 +01002049} // namespace armnn