blob: 607c86b11237b93812dab23425c80716050f2252 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010014#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000015
Matteo Martincighe011d202019-11-28 11:35:47 +000016#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <array>
20
telsoa014fcda012018-03-09 14:13:49 +000021using namespace boost;
22
23namespace armnn
24{
25
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010026namespace
27{
28
29template<typename Float32Func, typename Uint8Func, typename ... Params>
30bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
31 DataType dataType,
32 Float32Func floatFuncPtr,
33 Uint8Func uint8FuncPtr,
34 Params&&... params)
35{
36 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
37 dataType,
38 &FalseFunc<Params...>,
39 floatFuncPtr,
40 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000041 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000042 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010043 std::forward<Params>(params)...);
44}
45
46} // anonymous namespace
47
namespace
{

// Builds the human-readable reason reported when a tensor has an unexpected number of dimensions.
//
// expected   - number of dimensions the layer requires.
// actual     - number of dimensions the tensor actually has.
// layerStr   - layer name inserted after "Reference ".
// tensorName - tensor name quoted at the end of the message (e.g. "input", "output").
//
// Returns e.g. "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions instead, for the
// 'input' tensor."
//
// The string parameters are only read, so they are taken by const reference; this also lets callers
// pass temporaries and const strings.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000063
Sadik Armagan9199e582019-09-05 17:35:31 +010064bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
65 Optional<std::string&> reasonIfUnsupported) const
66{
josh minor4a3c6102020-01-06 16:40:46 -060067 return IsElementwiseUnarySupported(input,
68 output,
69 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
70 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010071}
72
arovir011c7c81b2018-10-08 11:34:28 +010073bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
74 const TensorInfo& output,
75 const ActivationDescriptor& descriptor,
76 Optional<std::string&> reasonIfUnsupported) const
77{
Derek Lamberti50db4e82019-03-13 14:16:15 +000078 bool supported = true;
79
80 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000081 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000082 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +000083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010084 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000085 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000086 DataType::QAsymmU8,
87 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000088 };
89
90 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
91 "Reference activation: input type not supported.");
92
93 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
94 "Reference activation: output type not supported.");
95
96 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
97 "Reference activation: input and output types mismatched.");
98
99 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
100 "Reference activation: input and output shapes are of different rank.");
101
102
103 struct ActivationFunctionSupported : public Rule
104 {
105 ActivationFunctionSupported(const ActivationDescriptor& desc)
106 {
107 switch(desc.m_Function)
108 {
109 case ActivationFunction::Abs:
110 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000111 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000112 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000113 case ActivationFunction::LeakyReLu:
114 case ActivationFunction::Linear:
115 case ActivationFunction::ReLu:
116 case ActivationFunction::Sigmoid:
117 case ActivationFunction::SoftReLu:
118 case ActivationFunction::Sqrt:
119 case ActivationFunction::Square:
120 case ActivationFunction::TanH:
121 {
122 m_Res = true;
123 break;
124 }
125 default:
126 {
127 m_Res = false;
128 break;
129 }
130 }
131 }
132 };
133
134 // Function is supported
135 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
136 "Reference activation: function not supported.");
137
138 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100139}
140
141bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
142 const TensorInfo& input1,
143 const TensorInfo& output,
144 Optional<std::string&> reasonIfUnsupported) const
145{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000146 bool supported = true;
147
Keith Davis0c2eeac2020-02-11 16:51:50 +0000148 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000149 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000150 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100151 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000152 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000153 DataType::QAsymmU8,
154 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000155 };
156
157 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
158 "Reference addition: input 0 is not a supported type.");
159
160 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
161 "Reference addition: input 1 is not a supported type.");
162
163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
164 "Reference addition: output is not a supported type.");
165
166 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
167 "Reference addition: input 0 and Input 1 types are mismatched");
168
169 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
170 "Reference addition: input and output types are mismatched");
171
172 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
173 "Reference addition: shapes are not suitable for implicit broadcast.");
174
175 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100176}
177
Nikhil Raj68c2c902019-09-19 11:21:11 +0100178bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
179 const armnn::ArgMinMaxDescriptor &descriptor,
180 armnn::Optional<std::string &> reasonIfUnsupported) const
181{
Jan Eilers8eb25602020-03-09 12:13:48 +0000182 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100183
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000184 std::array<DataType, 5> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100185 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000186 DataType::BFloat16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100187 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000188 DataType::QAsymmU8,
189 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000190 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100191 };
192
193 bool supported = true;
194
195 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
196 "Reference ArgMinMax: input is not a supported type.");
197 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
198 "Reference ArgMinMax: output type not supported");
199
200 return supported;
201}
202
arovir011c7c81b2018-10-08 11:34:28 +0100203bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
204 const TensorInfo& output,
205 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100206 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100207 const TensorInfo& beta,
208 const TensorInfo& gamma,
209 const BatchNormalizationDescriptor& descriptor,
210 Optional<std::string&> reasonIfUnsupported) const
211{
Jan Eilers8eb25602020-03-09 12:13:48 +0000212 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100213
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000214 std::array<DataType, 5> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100215 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000216 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100217 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100218 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000219 DataType::QAsymmU8,
220 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100221 };
222
223 bool supported = true;
224
225 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
226 "Reference batch normalization: input is not a supported type.");
227
228 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
229 "Reference batch normalization: output is not a supported type.");
230
231 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
232 "Reference batch normalization: input and output types are mismatched");
233
234 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
235 "Reference batch normalization: mean is not a supported type.");
236
237 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
238 "Reference batch normalization: variance is not a supported type.");
239
240 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
241 "Reference batch normalization: beta is not a supported type.");
242
243 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
244 "Reference batch normalization: gamma is not a supported type.");
245
246 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100247}
248
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000249bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
250 const TensorInfo& output,
251 const BatchToSpaceNdDescriptor& descriptor,
252 Optional<std::string&> reasonIfUnsupported) const
253{
Jan Eilers8eb25602020-03-09 12:13:48 +0000254 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100255
256 bool supported = true;
257
258 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
259 std::string inputTensorStr = "input";
260 std::string outputTensorStr = "output";
261
262 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000263 std::array<DataType,5> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100264 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000265 DataType::BFloat16,
266 DataType::Float32,
267 DataType::Float16,
268 DataType::QAsymmU8,
269 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100270 };
271
272 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
273 "Reference BatchToSpaceNd: input type not supported.");
274
275 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
276 "Reference BatchToSpaceNd: output type not supported.");
277
278 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
279 "Reference BatchToSpaceNd: input and output types mismatched.");
280
281 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
282 reasonIfUnsupported,
283 CreateIncorrectDimensionsErrorMsg(4,
284 output.GetNumDimensions(),
285 batchToSpaceNdLayerStr,
286 outputTensorStr).data());
287
288 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
289 reasonIfUnsupported,
290 CreateIncorrectDimensionsErrorMsg(4,
291 input.GetNumDimensions(),
292 batchToSpaceNdLayerStr,
293 inputTensorStr).data());
294
295 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000296}
297
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100298bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
299 const TensorInfo& input1,
300 const TensorInfo& output,
301 const ComparisonDescriptor& descriptor,
302 Optional<std::string&> reasonIfUnsupported) const
303{
Jan Eilers8eb25602020-03-09 12:13:48 +0000304 IgnoreUnused(descriptor);
Sadik Armaganb60dd242020-03-19 13:53:16 +0000305 std::array<DataType, 7> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100306 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000307 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000308 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100309 DataType::Float32,
310 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000311 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000312 DataType::QSymmS16,
313 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100314 };
315
316 bool supported = true;
317 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
318 "Reference comparison: input 0 is not a supported type");
319
320 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
321 "Reference comparison: input 0 and Input 1 types are mismatched");
322
323 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
324 "Reference comparison: output is not of type Boolean");
325
326 return supported;
327}
328
Jim Flynn906f9462019-05-10 13:55:21 +0100329bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
330 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100331 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100332 Optional<std::string&> reasonIfUnsupported) const
333{
Jan Eilers8eb25602020-03-09 12:13:48 +0000334 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100335
336 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000337 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100338 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000339 DataType::BFloat16,
340 DataType::Float32,
341 DataType::Float16,
342 DataType::QAsymmU8,
343 DataType::QAsymmS8,
344 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100345 };
346
347 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
348 "Reference concatenation: output type not supported");
349 for (const TensorInfo* input : inputs)
350 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100351 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100352 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
353 "Reference concatenation: input type not supported");
354
355 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
356 "Reference concatenation: input and output types mismatched.");
357 }
358
359 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100360}
361
arovir011c7c81b2018-10-08 11:34:28 +0100362bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
363 Optional<std::string&> reasonIfUnsupported) const
364{
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000365 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100366 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000367 DataType::BFloat16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100368 DataType::Float32,
369 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000370 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +0000371 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000372 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000373 DataType::QSymmS16
Nina Drozd58ef2c62019-05-16 12:09:18 +0100374 };
375
376 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
377 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100378}
379
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000380bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
381 const TensorInfo& output,
382 Optional<std::string&> reasonIfUnsupported) const
383{
384 bool supported = true;
385
386 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
387 "Reference for ConvertBf16ToFp32 layer: input type not supported");
388
389 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
390 "Reference for ConvertBf16ToFp32 layer: output type not supported");
391
392 return supported;
393}
394
arovir011c7c81b2018-10-08 11:34:28 +0100395bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
396 const TensorInfo& output,
397 Optional<std::string&> reasonIfUnsupported) const
398{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100399 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
400 input.GetDataType(),
401 &TrueFunc<>,
402 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000403 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000404 &FalseFuncI32<>,
405 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100406 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
407 output.GetDataType(),
408 &FalseOutputFuncF16<>,
409 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000410 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000411 &FalseFuncI32<>,
412 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100413}
414
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000415bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
416 const TensorInfo& output,
417 Optional<std::string&> reasonIfUnsupported) const
418{
419 bool supported = true;
420
421 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
422 "Reference for ConvertFp32ToBf16 layer: input type not supported");
423
424 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
425 "Reference for ConvertFp32ToBf16 layer: output type not supported");
426
427 return supported;
428}
429
arovir011c7c81b2018-10-08 11:34:28 +0100430bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
431 const TensorInfo& output,
432 Optional<std::string&> reasonIfUnsupported) const
433{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100434 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
435 input.GetDataType(),
436 &FalseInputFuncF16<>,
437 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000438 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000439 &FalseFuncI32<>,
440 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100441 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
442 output.GetDataType(),
443 &TrueFunc<>,
444 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000445 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000446 &FalseFuncI32<>,
447 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100448}
449
450bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
451 const TensorInfo& output,
452 const Convolution2dDescriptor& descriptor,
453 const TensorInfo& weights,
454 const Optional<TensorInfo>& biases,
455 Optional<std::string&> reasonIfUnsupported) const
456{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100457 bool supported = true;
458
459 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000460 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000461 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000462 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000463 DataType::Float32,
464 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000465 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000466 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000467 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000468 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100469 };
470
471 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000472 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100473
474 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000475 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100476
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000477 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
478 if (input.GetDataType() == DataType::BFloat16)
479 {
480 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
481 {
482 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
483 supported = false;
484 }
485 }
486 else
487 {
488 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000489 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000490 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100491
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000492 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000493 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000494 {
Derek Lambertid466a542020-01-22 15:37:29 +0000495 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000496 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000497 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000498 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000499 DataType::QSymmS8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000500 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000501 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000502 };
Derek Lambertid466a542020-01-22 15:37:29 +0000503 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000504
505 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000506 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000507 }
508 else
509 {
510 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000511 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000512
513 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000514 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000515 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100516
517 if (biases.has_value())
518 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000519 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000520 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000521 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000522 DataType::Float32,
523 DataType::Float16,
524 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100525 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000526
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100527 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000528 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100529 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000530 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100531
532 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100533}
534
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000535bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
536 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000537 Optional<std::string&> reasonIfUnsupported) const
538{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100539 bool supported = true;
540
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000541 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100542 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000543 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000544 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100545 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000546 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000547 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000548 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000549 DataType::QSymmS16,
550 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100551 };
552
553 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000554 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100555
556 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000557 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100558
559 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000560 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100561
562 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000563}
564
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100565bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
566 const TensorInfo& output,
567 const DepthToSpaceDescriptor& descriptor,
568 Optional<std::string&> reasonIfUnsupported) const
569{
Jan Eilers8eb25602020-03-09 12:13:48 +0000570 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100571 bool supported = true;
572
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000573 std::array<DataType,5> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100574 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000575 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100576 DataType::Float32,
577 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000578 DataType::QAsymmU8,
579 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100580 };
581
582 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
583 "Reference DepthToSpace: input type not supported");
584
585 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
586 "Reference DepthToSpace: output type not supported");
587
588 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
589 "Reference DepthToSpace: input and output types are mismatched");
590
591 return supported;
592}
593
arovir011c7c81b2018-10-08 11:34:28 +0100594bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
595 const TensorInfo& output,
596 const DepthwiseConvolution2dDescriptor& descriptor,
597 const TensorInfo& weights,
598 const Optional<TensorInfo>& biases,
599 Optional<std::string&> reasonIfUnsupported) const
600{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100601 bool supported = true;
602
603 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000604 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100605 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000606 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100607 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100608 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000609 DataType::QSymmS8,
610 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000611 DataType::QAsymmU8,
612 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100613 };
614
615 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
616 "Reference DepthwiseConvolution2d: input is not a supported type.");
617
618 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
619 "Reference DepthwiseConvolution2d: output is not a supported type.");
620
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100621 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
622 "Reference DepthwiseConvolution2d: input and output types mismatched.");
623
Derek Lambertid466a542020-01-22 15:37:29 +0000624 ARMNN_NO_DEPRECATE_WARN_BEGIN
625 std::array<DataType, 3> supportedWeightTypes =
626 {
627 DataType::QAsymmU8,
628 DataType::QSymmS8,
629 DataType::QuantizedSymm8PerAxis // deprecated
630 };
631 ARMNN_NO_DEPRECATE_WARN_END
632
Teresa Charlind8df0262019-11-11 12:28:15 +0000633 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000634 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000635 {
Teresa Charlind8df0262019-11-11 12:28:15 +0000636
637 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
638 "Reference convolution2d: weights type not supported for quantized input.");
639 }
640 else
641 {
642 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
643 "Reference DepthwiseConvolution2d: weights is not a supported type.");
644
645 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
646 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
647 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100648
649 if (biases.has_value())
650 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000651 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100652 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000653 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100654 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100655 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100656 DataType::Signed32
657 };
658 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
659 "Reference DepthwiseConvolution2d: biases is not a supported type.");
660 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000661 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100662
663 return supported;
664
arovir011c7c81b2018-10-08 11:34:28 +0100665}
666
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000667bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
668 const TensorInfo& output,
669 Optional<std::string&> reasonIfUnsupported) const
670{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100671 bool supported = true;
672
Ryan OShea9add1202020-02-07 10:06:33 +0000673 std::array<DataType,4> supportedInputTypes = {
674 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000675 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000676 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000677 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100678 };
679
680 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000681 "Reference for Dequantize layer: input type not supported.");
682
683 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
684 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100685
Derek Lambertid466a542020-01-22 15:37:29 +0000686 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
687 "Reference dequantize: per-axis quantized input not support .");
688
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000689 std::array<DataType,3> supportedOutputTypes = {
690 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +0000691 DataType::Float32,
692 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100693 };
694
695 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000696 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100697
698 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000699 "Reference for Dequantize layer: input/output shapes have different num total "
700 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100701
702 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000703}
704
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000705bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
706 const TensorInfo& scores,
707 const TensorInfo& anchors,
708 const TensorInfo& detectionBoxes,
709 const TensorInfo& detectionClasses,
710 const TensorInfo& detectionScores,
711 const TensorInfo& numDetections,
712 const DetectionPostProcessDescriptor& descriptor,
713 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000714{
Jan Eilers8eb25602020-03-09 12:13:48 +0000715 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +0000716
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100717 bool supported = true;
718
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000719 std::array<DataType,4> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100720 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000721 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100722 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000723 DataType::QAsymmU8,
724 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100725 };
726
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000727 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100728 "Reference DetectionPostProcess: input 0 is not a supported type.");
729
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000730 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100731 "Reference DetectionPostProcess: input 1 is not a supported type.");
732
733 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000734}
735
Pablo Tellof0bd6832019-04-26 17:58:13 +0100736bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
737 const TensorInfo& output,
738 const DepthwiseConvolution2dDescriptor& descriptor,
739 const TensorInfo& weights,
740 const Optional<TensorInfo>& biases,
741 Optional<std::string&> reasonIfUnsupported) const
742{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100743 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100744}
745
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100746bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100747 const TensorInfo& input1,
748 const TensorInfo& output,
749 Optional<std::string&> reasonIfUnsupported) const
750{
Sadik Armagan2999a022019-04-09 14:20:12 +0100751 bool supported = true;
752
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000753 std::array<DataType,5> supportedTypes = {
754 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100755 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100756 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000757 DataType::QAsymmU8,
758 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100759 };
760
761 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
762 "Reference division: input 0 is not a supported type.");
763
764 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
765 "Reference division: input 1 is not a supported type.");
766
767 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
768 "Reference division: output is not a supported type.");
769
770 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
771 "Reference division: input 0 and Input 1 types are mismatched");
772
773 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
774 "Reference division: input and output types are mismatched");
775
776 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
777 "Reference division: shapes are not suitable for implicit broadcast.");
778
779 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100780}
781
josh minor4a3c6102020-01-06 16:40:46 -0600782bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
783 const TensorInfo& output,
784 const ElementwiseUnaryDescriptor& descriptor,
785 Optional<std::string&> reasonIfUnsupported) const
786{
Jan Eilers8eb25602020-03-09 12:13:48 +0000787 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -0600788
Sadik Armaganac472102020-03-24 09:54:36 +0000789 std::array<DataType, 6> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -0600790 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000791 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -0600792 DataType::Float32,
793 DataType::Float16,
794 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +0000795 DataType::QSymmS16,
796 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -0600797 };
798
799 bool supported = true;
800
801 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
802 "Reference elementwise unary: input type not supported");
803
804 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
805 "Reference elementwise unary: output type not supported");
806
807 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
808 "Reference elementwise unary: input and output types not matching");
809
810 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
811 "Reference elementwise unary: input and output shapes"
812 "have different number of total elements");
813
814 return supported;
815}
816
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000817bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
818 const TensorInfo& input1,
819 const TensorInfo& output,
820 Optional<std::string&> reasonIfUnsupported) const
821{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100822 return IsComparisonSupported(input0,
823 input1,
824 output,
825 ComparisonDescriptor(ComparisonOperation::Equal),
826 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000827}
828
arovir011c7c81b2018-10-08 11:34:28 +0100829bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
830 const FakeQuantizationDescriptor& descriptor,
831 Optional<std::string&> reasonIfUnsupported) const
832{
Jan Eilers8eb25602020-03-09 12:13:48 +0000833 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100834 bool supported = true;
835
836 std::array<DataType,1> supportedTypes =
837 {
838 DataType::Float32
839 };
840
841 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
842 "Reference fake quantization: input type not supported.");
843
844 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100845}
846
847bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
848 const TensorInfo& output,
849 Optional<std::string&> reasonIfUnsupported) const
850{
Jan Eilers8eb25602020-03-09 12:13:48 +0000851 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +0100852 bool supported = true;
853
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000854 std::array<DataType,4> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100855 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000856 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +0100857 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100858 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000859 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100860 };
861
862 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863 "Reference Floor: input type not supported.");
864
865 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866 "Reference Floor: output type not supported.");
867
868 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100869}
870
871bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
872 const TensorInfo& output,
873 const TensorInfo& weights,
874 const TensorInfo& biases,
875 const FullyConnectedDescriptor& descriptor,
876 Optional<std::string&> reasonIfUnsupported) const
877{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100878 bool supported = true;
879
880 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000881 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100882 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000883 DataType::BFloat16,
884 DataType::Float32,
885 DataType::Float16,
886 DataType::QAsymmU8,
887 DataType::QAsymmS8,
888 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100889 };
890
891 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
892 "Reference Fully Connected: input type not supported.");
893
894 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
895 "Reference Fully Connected: output type not supported.");
896
Francis Murtagh46c09d02019-05-28 08:15:28 +0100897 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
898 "Reference Fully Connected: weights type not supported.");
899
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000900 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
901 if (input.GetDataType() == DataType::BFloat16)
902 {
903 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
904 {
905 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
906 supported = false;
907 }
908 }
909 else
910 {
911 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
912 "Reference Fully Connected: input and output types mismatched.");
913 }
914
Francis Murtaghddb1d062020-03-10 13:51:45 +0000915 ARMNN_NO_DEPRECATE_WARN_BEGIN
916 std::array<DataType, 3> supportedWeightTypes =
917 {
918 DataType::QAsymmU8,
919 DataType::QSymmS8,
920 DataType::QuantizedSymm8PerAxis // deprecated
921 };
922 ARMNN_NO_DEPRECATE_WARN_END
923
924 if (IsQuantized8BitType(input.GetDataType()))
925 {
926
927 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
928 "Reference Fully Connected: weights type not supported for quantized input.");
929 }
930 else
931 {
932 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
933 "Reference Fully Connected: weights is not a supported type.");
934
935 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
936 "Reference Fully Connected: input and weights types mismatched.");
937 }
Francis Murtagh46c09d02019-05-28 08:15:28 +0100938
939 if (descriptor.m_BiasEnabled)
940 {
941 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +0100942 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100943 supportedBiasTypes =
944 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000945 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100946 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100947 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +0100948 DataType::Signed32,
949 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +0100950 };
951
952 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
953 "Reference Fully Connected: bias type not supported.");
954
955 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
956 "Reference Fully Connected: bias and weight types mismatch.");
957
958 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
959 "Reference Fully Connected: bias type inferred from weights is incompatible.");
960
961 }
962
963 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100964}
965
narpra014951d842019-01-18 16:53:53 +0000966bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
967 const armnn::TensorInfo& input1,
968 const armnn::TensorInfo& output,
969 armnn::Optional<std::string&> reasonIfUnsupported) const
970{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100971 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000972 std::array<DataType,5> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100973 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000974 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100975 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100976 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000977 DataType::QAsymmU8,
978 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100979 };
980
981 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
982 "Reference Gather: input type not supported");
983
984 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
985 "Reference Gather: output type not supported");
986
987 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
988 "Reference Gather: indices (input1) type not supported");
989
990 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
991 "Reference Gather: input and output types not matching");
992
993 return supported;
narpra014951d842019-01-18 16:53:53 +0000994}
995
FrancisMurtagh878f0232018-12-19 10:56:15 +0000996bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
997 const TensorInfo& input1,
998 const TensorInfo& output,
999 Optional<std::string&> reasonIfUnsupported) const
1000{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001001 return IsComparisonSupported(input0,
1002 input1,
1003 output,
1004 ComparisonDescriptor(ComparisonOperation::Greater),
1005 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +00001006}
1007
Derek Lamberti901ea112019-12-10 22:07:09 +00001008bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1009 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001010{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001011 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001012}
1013
Kevin May09ca49c2019-10-09 12:37:34 +01001014bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1015 const TensorInfo& output,
1016 const InstanceNormalizationDescriptor& descriptor,
1017 Optional<std::string&> reasonIfUnsupported) const
1018{
Jan Eilers8eb25602020-03-09 12:13:48 +00001019 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001020 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001021 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001022 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001023 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001024 DataType::Float32,
1025 DataType::Float16
1026 };
1027
1028 bool supported = true;
1029
1030 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1031 "Reference Instance Normalization: input type not supported.");
1032
1033 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1034 "Reference Instance Normalization: output type not supported.");
1035
1036 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1037 "Reference Instance Normalization: input and output types mismatched.");
1038
1039 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1040 "Reference Instance Normalization: input and output shapes have different "
1041 "num total elements.");
1042
1043 return supported;
1044}
1045
arovir011c7c81b2018-10-08 11:34:28 +01001046bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1047 const TensorInfo& output,
1048 const L2NormalizationDescriptor& descriptor,
1049 Optional<std::string&> reasonIfUnsupported) const
1050{
Jan Eilers8eb25602020-03-09 12:13:48 +00001051 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001052 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001053 std::array<DataType, 5> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001054 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001055 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001056 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001057 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001058 DataType::QAsymmU8,
1059 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001060 };
1061
1062 bool supported = true;
1063
1064 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1065 "Reference L2normalization: input type not supported.");
1066
1067 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1068 "Reference L2normalization: output type not supported.");
1069
1070 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1071 "Reference L2normalization: input and output types mismatched.");
1072
1073 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1074 "Reference L2normalization: input and output shapes have different "
1075 "num total elements.");
1076
1077 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001078}
1079
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001080bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1081 const TensorInfo& output,
1082 const LogSoftmaxDescriptor& descriptor,
1083 Optional<std::string&> reasonIfUnsupported) const
1084{
Jan Eilers8eb25602020-03-09 12:13:48 +00001085 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001086
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001087 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001088 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001089 DataType::BFloat16,
1090 DataType::Float32,
1091 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001092 };
1093
1094 bool supported = true;
1095 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1096 "Reference LogSoftmax: input type not supported");
1097
1098 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1099 "Reference LogSoftmax: output type not supported");
1100
1101 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1102 "Reference LogSoftmax: input and output types do not match");
1103
1104 return supported;
1105}
1106
arovir011c7c81b2018-10-08 11:34:28 +01001107bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1108 const TensorInfo& outputStateIn,
1109 const TensorInfo& cellStateIn,
1110 const TensorInfo& scratchBuffer,
1111 const TensorInfo& outputStateOut,
1112 const TensorInfo& cellStateOut,
1113 const TensorInfo& output,
1114 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001115 const LstmInputParamsInfo& paramsInfo,
1116 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001117{
Jan Eilers8eb25602020-03-09 12:13:48 +00001118 IgnoreUnused(descriptor);
1119 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001120
1121 bool supported = true;
1122
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001123 std::array<DataType,3> supportedTypes = {
1124 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001125 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001126 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001127 };
1128
Jan Eilersd01a83c2019-07-03 18:20:40 +01001129 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001130 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1131 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001132 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1133 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001134 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1135 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001136 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1137 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001138 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1139 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001140 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1141 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001142 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1143 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001144 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001145 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001146 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001147 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001148 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001149 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001150 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001151 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001152 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001153 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001154 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001155 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001156 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001157 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001158 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001159 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001160 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001161 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001162 "Reference Lstm: input and OutputGateBias types are mismatched");
1163 if (!descriptor.m_CifgEnabled)
1164 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001165 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001166 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001167 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001168 reasonIfUnsupported,
1169 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001170 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001171 "Reference Lstm: input and InputGateBias types are mismatched");
1172 if (descriptor.m_PeepholeEnabled)
1173 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001174 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001175 reasonIfUnsupported,
1176 "Reference Lstm: input and CellToInputWeights types are mismatched");
1177 }
1178 }
1179 if (descriptor.m_PeepholeEnabled)
1180 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001181 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001182 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001183 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001184 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1185 }
1186 if (descriptor.m_ProjectionEnabled)
1187 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001188 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001189 "Reference Lstm: input and mProjectionWeights types are mismatched");
1190 if (paramsInfo.m_ProjectionBias != nullptr)
1191 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001192 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001193 "Reference Lstm: input and ProjectionBias types are mismatched");
1194 }
1195 }
1196 if (descriptor.m_LayerNormEnabled)
1197 {
1198 if (!descriptor.m_CifgEnabled)
1199 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001200 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001201 reasonIfUnsupported,
1202 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1203 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001204 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001205 reasonIfUnsupported,
1206 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001207 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001208 reasonIfUnsupported,
1209 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001210 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001211 reasonIfUnsupported,
1212 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1213 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001214
1215 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001216}
1217
saoste012df12b32018-11-28 16:57:20 +00001218bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1219 const TensorInfo& input1,
1220 const TensorInfo& output,
1221 Optional<std::string&> reasonIfUnsupported) const
1222{
Sadik Armagan2999a022019-04-09 14:20:12 +01001223 bool supported = true;
1224
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001225 std::array<DataType,6> supportedTypes = {
1226 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001227 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001228 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001229 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001230 DataType::QAsymmU8,
1231 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001232 };
1233
1234 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1235 "Reference maximum: input 0 is not a supported type.");
1236
1237 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1238 "Reference maximum: input 1 is not a supported type.");
1239
1240 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1241 "Reference maximum: output is not a supported type.");
1242
1243 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1244 "Reference maximum: input 0 and Input 1 types are mismatched");
1245
1246 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1247 "Reference maximum: input and output types are mismatched");
1248
1249 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1250 "Reference maximum: shapes are not suitable for implicit broadcast.");
1251
1252 return supported;
saoste012df12b32018-11-28 16:57:20 +00001253}
1254
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001255bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1256 const TensorInfo& output,
1257 const MeanDescriptor& descriptor,
1258 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001259{
James Conroy4d1ff582019-06-10 17:06:39 +01001260 bool supported = true;
1261 std::string meanLayerStr = "Mean";
1262 std::string outputTensorStr = "output";
1263
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001264 std::array<DataType,5> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001265 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001266 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001267 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001268 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001269 DataType::QAsymmU8,
1270 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001271 };
1272
1273 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1274 "Reference Mean: input type not supported.");
1275
1276 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1277 "Reference Mean: input and output types are mismatched");
1278
1279 if (descriptor.m_KeepDims)
1280 {
1281 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1282 reasonIfUnsupported,
1283 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1284 output.GetNumDimensions(),
1285 meanLayerStr, outputTensorStr).data());
1286 }
1287 else if (descriptor.m_Axis.empty())
1288 {
1289 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1290 reasonIfUnsupported,
1291 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1292 meanLayerStr, outputTensorStr).data());
1293 }
1294 else
1295 {
1296 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1297
1298 if (outputDim > 0)
1299 {
1300 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1301 reasonIfUnsupported,
1302 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1303 meanLayerStr, outputTensorStr).data());
1304 }
1305 else
1306 {
1307 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1308 reasonIfUnsupported,
1309 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1310 meanLayerStr, outputTensorStr).data());
1311 }
1312 }
1313
1314 return supported;
narpra0132b90462018-09-13 11:07:48 +01001315}
1316
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001317bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001318 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001319 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001320 Optional<std::string&> reasonIfUnsupported) const
1321{
Jim Flynne242f2d2019-05-22 14:24:13 +01001322 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001323}
1324
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001325bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1326 const TensorInfo &output,
1327 Optional<std::string &> reasonIfUnsupported) const
1328{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001329 bool supported = true;
1330
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001331 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001332 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001333 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001334 DataType::Float32,
1335 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001336 DataType::QAsymmU8,
1337 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001338 DataType::Boolean
1339 };
1340
1341 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1342 "Reference MemCopy: input type not supported");
1343
1344 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1345 "Reference MemCopy: output type not supported");
1346
1347 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1348 "Reference MemCopy: input and output types are mismatched");
1349
1350 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001351}
1352
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001353bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1354 const TensorInfo& input1,
1355 const TensorInfo& output,
1356 Optional<std::string&> reasonIfUnsupported) const
1357{
Sadik Armagan2999a022019-04-09 14:20:12 +01001358 bool supported = true;
1359
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001360 std::array<DataType,5> supportedTypes = {
1361 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001362 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001363 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001364 DataType::QAsymmU8,
1365 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001366 };
1367
1368 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1369 "Reference minimum: input 0 is not a supported type.");
1370
1371 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1372 "Reference minimum: input 1 is not a supported type.");
1373
1374 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1375 "Reference minimum: output is not a supported type.");
1376
1377 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1378 "Reference minimum: input 0 and Input 1 types are mismatched");
1379
1380 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1381 "Reference minimum: input and output types are mismatched");
1382
1383 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1384 "Reference minimum: shapes are not suitable for implicit broadcast.");
1385
1386 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001387}
1388
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001389bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1390 const TensorInfo& input1,
1391 const TensorInfo& output,
1392 Optional<std::string&> reasonIfUnsupported) const
1393{
Sadik Armagan2999a022019-04-09 14:20:12 +01001394 bool supported = true;
1395
Keith Davis67e6c542020-02-19 10:08:33 +00001396 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001397 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001398 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001399 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001400 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001401 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001402 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001403 };
1404
1405 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1406 "Reference multiplication: input 0 is not a supported type.");
1407
1408 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1409 "Reference multiplication: input 1 is not a supported type.");
1410
1411 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1412 "Reference multiplication: output is not a supported type.");
1413
1414 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1415 "Reference multiplication: input 0 and Input 1 types are mismatched");
1416
1417 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1418 "Reference multiplication: input and output types are mismatched");
1419
1420 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1421 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1422
1423 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001424}
1425
1426bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1427 const TensorInfo& output,
1428 const NormalizationDescriptor& descriptor,
1429 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001430{
Jan Eilers8eb25602020-03-09 12:13:48 +00001431 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001432
1433 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001434 std::array<DataType, 5> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001435 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001436 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001437 DataType::Float16,
1438 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001439 DataType::QAsymmU8,
1440 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001441 };
1442
1443 bool supported = true;
1444
1445 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1446 "Reference normalization: input type not supported.");
1447
1448 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1449 "Reference normalization: output type not supported.");
1450
1451 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1452 "Reference normalization: input and output shapes have different "
1453 "num total elements.");
1454
1455 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001456}
1457
Derek Lamberti901ea112019-12-10 22:07:09 +00001458bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1459 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001460{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001461 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001462}
1463
1464bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1465 const TensorInfo& output,
1466 const PadDescriptor& descriptor,
1467 Optional<std::string&> reasonIfUnsupported) const
1468{
Jan Eilers8eb25602020-03-09 12:13:48 +00001469 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001470 bool supported = true;
1471
1472 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001473 std::array<DataType,5> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001474 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001475 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001476 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001477 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001478 DataType::QAsymmU8,
1479 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001480 };
1481
1482 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1483 "Reference pad: input is not a supported type.");
1484
1485 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1486 "Reference pad: output is not a supported type.");
1487
1488 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1489 "Reference pad: input and output types are mismatched.");
1490
1491 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001492}
1493
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001494bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1495 const TensorInfo& output,
1496 const PermuteDescriptor& descriptor,
1497 Optional<std::string&> reasonIfUnsupported) const
1498{
Jan Eilers8eb25602020-03-09 12:13:48 +00001499 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001500 bool supported = true;
1501
1502 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001503 std::array<DataType, 5> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001504 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001505 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001506 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001507 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001508 DataType::QAsymmU8,
1509 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001510 };
1511
1512 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1513 "Reference permute: input is not a supported type.");
1514
1515 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1516 "Reference permute: output is not a supported type.");
1517
1518 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1519 "Reference permute: input and output types are mismatched.");
1520
1521 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001522}
1523
1524bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1525 const TensorInfo& output,
1526 const Pooling2dDescriptor& descriptor,
1527 Optional<std::string&> reasonIfUnsupported) const
1528{
Jan Eilers8eb25602020-03-09 12:13:48 +00001529 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001530 bool supported = true;
1531
1532 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001533 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001534 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001536 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001537 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001538 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001539 DataType::QAsymmU8,
1540 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001541 };
1542
1543 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1544 "Reference poolind2d: input is not a supported type.");
1545
1546 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1547 "Reference poolind2d: output is not a supported type.");
1548
1549 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1550 "Reference poolind2d: input and output types are mismatched.");
1551
1552 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001553}
1554
Derek Lamberti5f400d62019-03-25 15:41:58 +00001555bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1556 const TensorInfo& output,
1557 Optional<std::string&> reasonIfUnsupported) const
1558{
1559 bool supported = true;
1560
Finn Williamsfd271062019-12-04 14:27:27 +00001561 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001562 std::array<DataType,7> supportedInputTypes = {
1563 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001564 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001565 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001566 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001567 DataType::QAsymmU8,
1568 DataType::QSymmS8,
1569 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001570 };
1571
1572 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1573 "Reference quantize: input type not supported.");
1574
1575 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001576 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001577 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001578 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001579 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001580 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001581 };
1582 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1583 "Reference quantize: output type not supported.");
1584
1585 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1586 "Reference quantize: input and output shapes have different num total elements.");
1587
1588 return supported;
1589}
1590
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001591bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001592 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001593 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001594 Optional<std::string&> reasonIfUnsupported) const
1595{
Jan Eilers8eb25602020-03-09 12:13:48 +00001596 IgnoreUnused(output);
1597 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001598 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001599 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001600 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001601 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001602 DataType::Float32,
1603 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001604 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001605 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001606 DataType::QAsymmU8,
1607 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001608 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001609
Nina Drozd2f2778f2019-05-27 10:37:05 +01001610 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1611 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001612}
1613
1614bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001615 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001616 Optional<std::string&> reasonIfUnsupported) const
1617{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001618 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001619 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001620 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001621 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001622 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001623 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001624 DataType::QAsymmU8,
1625 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001626 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001627
1628 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1629 "Reference ResizeBilinear: input type not supported");
1630
1631 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1632 "Reference ResizeBilinear: output type not supported");
1633
1634 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1635 "Reference ResizeBilinear: input and output types not matching");
1636
1637 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001638}
1639
Teresa Charlin970f43b2019-07-01 13:51:07 +01001640bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1641 const TensorInfo& output,
1642 const ResizeDescriptor& descriptor,
1643 Optional<std::string&> reasonIfUnsupported) const
1644{
Jan Eilers8eb25602020-03-09 12:13:48 +00001645 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001646 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001647 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001648 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001649 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001650 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001651 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001652 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001653 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001654 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001655 };
1656
1657 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1658 "Reference Resize: input type not supported");
1659
1660 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1661 "Reference Resize: output type not supported");
1662
1663 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1664 "Reference Resize: input and output types not matching");
1665
1666 return supported;
1667}
1668
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001669bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1670 const TensorInfo& output,
1671 Optional<std::string&> reasonIfUnsupported) const
1672{
josh minor4a3c6102020-01-06 16:40:46 -06001673 return IsElementwiseUnarySupported(input,
1674 output,
1675 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1676 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001677}
1678
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001679bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1680 const TensorInfo& output,
1681 const SliceDescriptor& descriptor,
1682 Optional<std::string&> reasonIfUnsupported) const
1683{
Jan Eilers8eb25602020-03-09 12:13:48 +00001684 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001685 bool supported = true;
1686
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001687 std::array<DataType, 4> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001688 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001689 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001690 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001691 DataType::QAsymmU8,
1692 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001693 };
1694
1695 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1696 "Reference Slice: input type not supported");
1697
1698 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1699 "Reference Slice: output type not supported");
1700
1701 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1702 "Reference Slice: input and output types are mismatched");
1703
1704 return supported;
1705}
1706
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001707bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1708 const TensorInfo& output,
1709 const SoftmaxDescriptor& descriptor,
1710 Optional<std::string&> reasonIfUnsupported) const
1711{
Jan Eilers8eb25602020-03-09 12:13:48 +00001712 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001713 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001714 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 DataType::BFloat16,
1717 DataType::Float32,
1718 DataType::Float16,
1719 DataType::QSymmS8,
1720 DataType::QAsymmS8,
1721 DataType::QAsymmU8,
1722 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001723 };
1724
1725 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001726 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001727
1728 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001729 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001730
1731 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001732 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001733
1734 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001735}
1736
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001737bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1738 const TensorInfo& output,
1739 const SpaceToBatchNdDescriptor& descriptor,
1740 Optional<std::string&> reasonIfUnsupported) const
1741{
Jan Eilers8eb25602020-03-09 12:13:48 +00001742 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001743 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001744 std::array<DataType,5> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001745 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001746 DataType::BFloat16,
1747 DataType::Float32,
1748 DataType::Float16,
1749 DataType::QAsymmU8,
1750 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001751 };
1752
1753 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1754 "Reference SpaceToBatchNd: input type not supported");
1755
1756 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1757 "Reference SpaceToBatchNd: output type not supported");
1758
1759 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1760 "Reference SpaceToBatchNd: input and output types are mismatched");
1761
1762 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001763}
1764
Keith Davisa57eccb2019-06-14 17:33:22 +01001765bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001766 const TensorInfo& output,
1767 const SpaceToDepthDescriptor& descriptor,
1768 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001769{
1770
Jan Eilers8eb25602020-03-09 12:13:48 +00001771 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001772 bool supported = true;
1773
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001774 std::array<DataType,5> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001775 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001776 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001777 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001778 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001779 DataType::QAsymmU8,
1780 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001781 };
1782
1783 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1784 "Reference SpaceToDepth: input type not supported");
1785
1786 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1787 "Reference SpaceToDepth: output type not supported");
1788
1789 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1790 "Reference SpaceToDepth: input and output types are mismatched");
1791
1792 return supported;
1793}
1794
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001795bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1796 const ViewsDescriptor& descriptor,
1797 Optional<std::string&> reasonIfUnsupported) const
1798{
Jan Eilers8eb25602020-03-09 12:13:48 +00001799 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001800 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001801 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001802 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001803 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001804 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001805 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001806 DataType::QAsymmU8,
1807 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001808 };
1809
1810 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1811 "Reference splitter: input type not supported");
1812
1813 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001814}
1815
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001816bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1817 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1818 const ViewsDescriptor& descriptor,
1819 Optional<std::string&> reasonIfUnsupported) const
1820{
Jan Eilers8eb25602020-03-09 12:13:48 +00001821 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001822 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001823 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001824 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001825 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001826 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001827 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001828 DataType::QAsymmU8,
1829 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001830 };
1831
1832 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1833 "Reference splitter: output type not supported");
1834 for (const TensorInfo output : outputs)
1835 {
1836 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1837 "Reference splitter: input type not supported");
1838
1839 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1840 "Reference splitter: input and output types mismatched.");
1841 }
1842
1843 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001844}
1845
Matthew Jackson81e601c2019-07-11 12:07:09 +01001846bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1847 const TensorInfo& output,
1848 const StackDescriptor& descriptor,
1849 Optional<std::string&> reasonIfUnsupported) const
1850{
Jan Eilers8eb25602020-03-09 12:13:48 +00001851 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001852
1853 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001854 std::array<DataType,5> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001855 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001856 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001857 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001858 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001859 DataType::QAsymmU8,
1860 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001861 };
1862
1863 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1864 "Reference stack: output type not supported");
1865 for (const TensorInfo* input : inputs)
1866 {
1867 BOOST_ASSERT(input != nullptr);
1868 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1869 "Reference stack: input type not supported");
1870
1871 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1872 "Reference stack: input and output types mismatched.");
1873 }
1874
1875 return supported;
1876}
1877
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001878bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1879 const TensorInfo& output,
1880 const StridedSliceDescriptor& descriptor,
1881 Optional<std::string&> reasonIfUnsupported) const
1882{
Jan Eilers8eb25602020-03-09 12:13:48 +00001883 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001884 bool supported = true;
1885
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001886 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001887 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001888 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001889 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001890 DataType::QAsymmU8,
1891 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001892 };
1893
1894 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1895 "Reference StridedSlice: input type not supported");
1896
1897 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1898 "Reference StridedSlice: output type not supported");
1899
1900 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1901 "Reference StridedSlice: input and output types are mismatched");
1902
1903 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001904}
1905
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001906bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1907 const TensorInfo& input1,
1908 const TensorInfo& output,
1909 Optional<std::string&> reasonIfUnsupported) const
1910{
Sadik Armagan2999a022019-04-09 14:20:12 +01001911 bool supported = true;
1912
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001913 std::array<DataType,5> supportedTypes = {
1914 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001915 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001916 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001917 DataType::QAsymmU8,
1918 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001919 };
1920
1921 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1922 "Reference subtraction: input 0 is not a supported type.");
1923
1924 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1925 "Reference subtraction: input 1 is not a supported type.");
1926
1927 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1928 "Reference subtraction: output is not a supported type.");
1929
1930 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1931 "Reference subtraction: input 0 and Input 1 types are mismatched");
1932
1933 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1934 "Reference subtraction: input and output types are mismatched");
1935
1936 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1937 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1938
1939 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001940}
1941
Matteo Martincighab9e5252019-06-13 17:27:46 +01001942bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1943 const TensorInfo& alpha,
1944 const TensorInfo& output,
1945 Optional<std::string&> reasonIfUnsupported) const
1946{
1947 bool supported = true;
1948
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001949 std::array<DataType, 5> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001950 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001951 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001952 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001953 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001954 DataType::QAsymmU8,
1955 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001956 };
1957
1958 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1959 "PReLU: input is not a supported type.");
1960
1961 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1962 "PReLU: alpha is not a supported type.");
1963
1964 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1965 "PReLU: output is not a supported type.");
1966
1967 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1968 "PReLU: input, alpha and output types are mismatched");
1969
1970 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1971 "PReLU: shapes are not suitable for implicit broadcast");
1972
1973 return supported;
1974}
1975
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001976bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1977 const TensorInfo& output,
1978 const TransposeConvolution2dDescriptor& descriptor,
1979 const TensorInfo& weights,
1980 const Optional<TensorInfo>& biases,
1981 Optional<std::string&> reasonIfUnsupported) const
1982{
Jan Eilers8eb25602020-03-09 12:13:48 +00001983 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001984 bool supported = true;
1985
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001986 std::array<DataType,5> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001987 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001988 DataType::BFloat16,
1989 DataType::Float32,
1990 DataType::Float16,
1991 DataType::QAsymmU8,
1992 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001993 };
1994
1995 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1996 "Reference TransposeConvolution2d: input is not a supported type.");
1997
1998 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1999 "Reference TransposeConvolution2d: output is not a supported type.");
2000
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002001 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2002 "Reference TransposeConvolution2d: input and output types mismatched.");
2003
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002004
2005 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00002006 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002007 {
Derek Lambertid466a542020-01-22 15:37:29 +00002008 ARMNN_NO_DEPRECATE_WARN_BEGIN
2009 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002010 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002011 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00002012 DataType::QSymmS8,
2013 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002014 };
Derek Lambertid466a542020-01-22 15:37:29 +00002015 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002016
2017 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2018 "Reference TransposeConvolution2d: weights type not supported for "
2019 "quantized input.");
2020 }
2021 else
2022 {
2023 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2024 "Reference TransposeConvolution2d: weights is not a supported type.");
2025
2026 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2027 "Reference TransposeConvolution2d: input and weights types mismatched.");
2028 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002029
2030 if (biases.has_value())
2031 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002032 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002033 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002034 DataType::BFloat16,
2035 DataType::Float32,
2036 DataType::Float16,
2037 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002038 };
2039 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2040 "Reference TransposeConvolution2d: biases is not a supported type.");
2041 }
2042
2043 return supported;
2044}
2045
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002046bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2047 const TensorInfo& output,
2048 const TransposeDescriptor& descriptor,
2049 Optional<std::string&> reasonIfUnsupported) const
2050{
Jan Eilers8eb25602020-03-09 12:13:48 +00002051 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002052 bool supported = true;
2053
2054 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002055 std::array<DataType, 5> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002056 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002057 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002058 DataType::Float32,
2059 DataType::Float16,
2060 DataType::QAsymmU8,
2061 DataType::QSymmS16
2062 };
2063
2064 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2065 "Reference transpose: input is not a supported type.");
2066
2067 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2068 "Reference transpose: output is not a supported type.");
2069
2070 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2071 "Reference transpose: input and output types are mismatched.");
2072
2073 return supported;
2074}
2075
arovir011c7c81b2018-10-08 11:34:28 +01002076} // namespace armnn