blob: a4f4efd92aa7f144224220d1c394c11293c9620a [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010014#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000015
Matteo Martincighe011d202019-11-28 11:35:47 +000016#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <array>
20
telsoa014fcda012018-03-09 14:13:49 +000021using namespace boost;
22
23namespace armnn
24{
25
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010026namespace
27{
28
29template<typename Float32Func, typename Uint8Func, typename ... Params>
30bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
31 DataType dataType,
32 Float32Func floatFuncPtr,
33 Uint8Func uint8FuncPtr,
34 Params&&... params)
35{
36 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
37 dataType,
38 &FalseFunc<Params...>,
39 floatFuncPtr,
40 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000041 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000042 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010043 std::forward<Params>(params)...);
44}
45
46} // anonymous namespace
47
namespace
{

/// Builds the reason string reported when a tensor has an unexpected number of dimensions.
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Layer name, inserted after the "Reference " prefix.
/// @param tensorName Name of the offending tensor, quoted in the message.
/// @return e.g. "Reference foo: Expected 4 dimensions but got 3 dimensions instead, for the 'input' tensor."
/// The string arguments are only read, so they are taken by const reference. This also lets callers
/// pass temporaries or const strings.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000063
Sadik Armagan9199e582019-09-05 17:35:31 +010064bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
65 Optional<std::string&> reasonIfUnsupported) const
66{
josh minor4a3c6102020-01-06 16:40:46 -060067 return IsElementwiseUnarySupported(input,
68 output,
69 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
70 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010071}
72
arovir011c7c81b2018-10-08 11:34:28 +010073bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
74 const TensorInfo& output,
75 const ActivationDescriptor& descriptor,
76 Optional<std::string&> reasonIfUnsupported) const
77{
Derek Lamberti50db4e82019-03-13 14:16:15 +000078 bool supported = true;
79
80 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000081 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000082 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +000083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010084 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000085 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000086 DataType::QAsymmU8,
87 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000088 };
89
90 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
91 "Reference activation: input type not supported.");
92
93 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
94 "Reference activation: output type not supported.");
95
96 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
97 "Reference activation: input and output types mismatched.");
98
99 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
100 "Reference activation: input and output shapes are of different rank.");
101
102
103 struct ActivationFunctionSupported : public Rule
104 {
105 ActivationFunctionSupported(const ActivationDescriptor& desc)
106 {
107 switch(desc.m_Function)
108 {
109 case ActivationFunction::Abs:
110 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000111 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000112 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000113 case ActivationFunction::LeakyReLu:
114 case ActivationFunction::Linear:
115 case ActivationFunction::ReLu:
116 case ActivationFunction::Sigmoid:
117 case ActivationFunction::SoftReLu:
118 case ActivationFunction::Sqrt:
119 case ActivationFunction::Square:
120 case ActivationFunction::TanH:
121 {
122 m_Res = true;
123 break;
124 }
125 default:
126 {
127 m_Res = false;
128 break;
129 }
130 }
131 }
132 };
133
134 // Function is supported
135 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
136 "Reference activation: function not supported.");
137
138 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100139}
140
141bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
142 const TensorInfo& input1,
143 const TensorInfo& output,
144 Optional<std::string&> reasonIfUnsupported) const
145{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000146 bool supported = true;
147
Keith Davis0c2eeac2020-02-11 16:51:50 +0000148 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000149 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000150 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100151 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000152 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000153 DataType::QAsymmU8,
154 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000155 };
156
157 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
158 "Reference addition: input 0 is not a supported type.");
159
160 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
161 "Reference addition: input 1 is not a supported type.");
162
163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
164 "Reference addition: output is not a supported type.");
165
166 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
167 "Reference addition: input 0 and Input 1 types are mismatched");
168
169 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
170 "Reference addition: input and output types are mismatched");
171
172 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
173 "Reference addition: shapes are not suitable for implicit broadcast.");
174
175 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100176}
177
Nikhil Raj68c2c902019-09-19 11:21:11 +0100178bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
179 const armnn::ArgMinMaxDescriptor &descriptor,
180 armnn::Optional<std::string &> reasonIfUnsupported) const
181{
Jan Eilers8eb25602020-03-09 12:13:48 +0000182 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100183
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000184 std::array<DataType, 5> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100185 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000186 DataType::BFloat16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100187 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000188 DataType::QAsymmU8,
189 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000190 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100191 };
192
193 bool supported = true;
194
195 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
196 "Reference ArgMinMax: input is not a supported type.");
197 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
198 "Reference ArgMinMax: output type not supported");
199
200 return supported;
201}
202
arovir011c7c81b2018-10-08 11:34:28 +0100203bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
204 const TensorInfo& output,
205 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100206 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100207 const TensorInfo& beta,
208 const TensorInfo& gamma,
209 const BatchNormalizationDescriptor& descriptor,
210 Optional<std::string&> reasonIfUnsupported) const
211{
Jan Eilers8eb25602020-03-09 12:13:48 +0000212 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100213
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000214 std::array<DataType, 5> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100215 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000216 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100217 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100218 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000219 DataType::QAsymmU8,
220 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100221 };
222
223 bool supported = true;
224
225 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
226 "Reference batch normalization: input is not a supported type.");
227
228 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
229 "Reference batch normalization: output is not a supported type.");
230
231 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
232 "Reference batch normalization: input and output types are mismatched");
233
234 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
235 "Reference batch normalization: mean is not a supported type.");
236
237 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
238 "Reference batch normalization: variance is not a supported type.");
239
240 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
241 "Reference batch normalization: beta is not a supported type.");
242
243 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
244 "Reference batch normalization: gamma is not a supported type.");
245
246 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100247}
248
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000249bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
250 const TensorInfo& output,
251 const BatchToSpaceNdDescriptor& descriptor,
252 Optional<std::string&> reasonIfUnsupported) const
253{
Jan Eilers8eb25602020-03-09 12:13:48 +0000254 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100255
256 bool supported = true;
257
258 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
259 std::string inputTensorStr = "input";
260 std::string outputTensorStr = "output";
261
262 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000263 std::array<DataType,5> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100264 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000265 DataType::BFloat16,
266 DataType::Float32,
267 DataType::Float16,
268 DataType::QAsymmU8,
269 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100270 };
271
272 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
273 "Reference BatchToSpaceNd: input type not supported.");
274
275 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
276 "Reference BatchToSpaceNd: output type not supported.");
277
278 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
279 "Reference BatchToSpaceNd: input and output types mismatched.");
280
281 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
282 reasonIfUnsupported,
283 CreateIncorrectDimensionsErrorMsg(4,
284 output.GetNumDimensions(),
285 batchToSpaceNdLayerStr,
286 outputTensorStr).data());
287
288 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
289 reasonIfUnsupported,
290 CreateIncorrectDimensionsErrorMsg(4,
291 input.GetNumDimensions(),
292 batchToSpaceNdLayerStr,
293 inputTensorStr).data());
294
295 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000296}
297
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100298bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
299 const TensorInfo& input1,
300 const TensorInfo& output,
301 const ComparisonDescriptor& descriptor,
302 Optional<std::string&> reasonIfUnsupported) const
303{
Jan Eilers8eb25602020-03-09 12:13:48 +0000304 IgnoreUnused(descriptor);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100305
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000306 std::array<DataType, 5> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100307 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000308 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100309 DataType::Float32,
310 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000311 DataType::QAsymmU8,
312 DataType::QSymmS16
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100313 };
314
315 bool supported = true;
316 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
317 "Reference comparison: input 0 is not a supported type");
318
319 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
320 "Reference comparison: input 0 and Input 1 types are mismatched");
321
322 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
323 "Reference comparison: output is not of type Boolean");
324
325 return supported;
326}
327
Jim Flynn906f9462019-05-10 13:55:21 +0100328bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
329 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100330 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100331 Optional<std::string&> reasonIfUnsupported) const
332{
Jan Eilers8eb25602020-03-09 12:13:48 +0000333 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100334
335 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000336 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100337 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000338 DataType::BFloat16,
339 DataType::Float32,
340 DataType::Float16,
341 DataType::QAsymmU8,
342 DataType::QAsymmS8,
343 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100344 };
345
346 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
347 "Reference concatenation: output type not supported");
348 for (const TensorInfo* input : inputs)
349 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100350 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100351 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
352 "Reference concatenation: input type not supported");
353
354 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
355 "Reference concatenation: input and output types mismatched.");
356 }
357
358 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100359}
360
arovir011c7c81b2018-10-08 11:34:28 +0100361bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
362 Optional<std::string&> reasonIfUnsupported) const
363{
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000364 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100365 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000366 DataType::BFloat16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100367 DataType::Float32,
368 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000369 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +0000370 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000371 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000372 DataType::QSymmS16
Nina Drozd58ef2c62019-05-16 12:09:18 +0100373 };
374
375 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
376 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100377}
378
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000379bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
380 const TensorInfo& output,
381 Optional<std::string&> reasonIfUnsupported) const
382{
383 bool supported = true;
384
385 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
386 "Reference for ConvertBf16ToFp32 layer: input type not supported");
387
388 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
389 "Reference for ConvertBf16ToFp32 layer: output type not supported");
390
391 return supported;
392}
393
arovir011c7c81b2018-10-08 11:34:28 +0100394bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
395 const TensorInfo& output,
396 Optional<std::string&> reasonIfUnsupported) const
397{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100398 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
399 input.GetDataType(),
400 &TrueFunc<>,
401 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000402 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000403 &FalseFuncI32<>,
404 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100405 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
406 output.GetDataType(),
407 &FalseOutputFuncF16<>,
408 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000409 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000410 &FalseFuncI32<>,
411 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100412}
413
414bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
415 const TensorInfo& output,
416 Optional<std::string&> reasonIfUnsupported) const
417{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100418 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
419 input.GetDataType(),
420 &FalseInputFuncF16<>,
421 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000422 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000423 &FalseFuncI32<>,
424 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100425 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
426 output.GetDataType(),
427 &TrueFunc<>,
428 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000429 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000430 &FalseFuncI32<>,
431 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100432}
433
434bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
435 const TensorInfo& output,
436 const Convolution2dDescriptor& descriptor,
437 const TensorInfo& weights,
438 const Optional<TensorInfo>& biases,
439 Optional<std::string&> reasonIfUnsupported) const
440{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100441 bool supported = true;
442
443 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000444 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000445 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000446 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000447 DataType::Float32,
448 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000449 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000450 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000451 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000452 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100453 };
454
455 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000456 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100457
458 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000459 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100460
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100461 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000462 "Reference Convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100463
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000464 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000465 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000466 {
Derek Lambertid466a542020-01-22 15:37:29 +0000467 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000468 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000469 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000470 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000471 DataType::QSymmS8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000472 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000473 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000474 };
Derek Lambertid466a542020-01-22 15:37:29 +0000475 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000476
477 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000478 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000479 }
480 else
481 {
482 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000483 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000484
485 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000486 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000487 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100488
489 if (biases.has_value())
490 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000491 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000492 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000493 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000494 DataType::Float32,
495 DataType::Float16,
496 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100497 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000498
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100499 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000500 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100501 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000502 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100503
504 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100505}
506
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000507bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
508 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000509 Optional<std::string&> reasonIfUnsupported) const
510{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100511 bool supported = true;
512
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000513 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100514 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000515 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000516 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100517 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000518 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000519 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000520 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000521 DataType::QSymmS16,
522 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100523 };
524
525 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000526 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100527
528 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000529 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100530
531 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000532 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100533
534 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000535}
536
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100537bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
538 const TensorInfo& output,
539 const DepthToSpaceDescriptor& descriptor,
540 Optional<std::string&> reasonIfUnsupported) const
541{
Jan Eilers8eb25602020-03-09 12:13:48 +0000542 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100543 bool supported = true;
544
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000545 std::array<DataType,5> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100546 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000547 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100548 DataType::Float32,
549 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000550 DataType::QAsymmU8,
551 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100552 };
553
554 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
555 "Reference DepthToSpace: input type not supported");
556
557 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
558 "Reference DepthToSpace: output type not supported");
559
560 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
561 "Reference DepthToSpace: input and output types are mismatched");
562
563 return supported;
564}
565
arovir011c7c81b2018-10-08 11:34:28 +0100566bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
567 const TensorInfo& output,
568 const DepthwiseConvolution2dDescriptor& descriptor,
569 const TensorInfo& weights,
570 const Optional<TensorInfo>& biases,
571 Optional<std::string&> reasonIfUnsupported) const
572{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100573 bool supported = true;
574
575 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000576 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100577 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000578 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100579 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100580 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000581 DataType::QSymmS8,
582 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000583 DataType::QAsymmU8,
584 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100585 };
586
587 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
588 "Reference DepthwiseConvolution2d: input is not a supported type.");
589
590 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
591 "Reference DepthwiseConvolution2d: output is not a supported type.");
592
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100593 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
594 "Reference DepthwiseConvolution2d: input and output types mismatched.");
595
Derek Lambertid466a542020-01-22 15:37:29 +0000596 ARMNN_NO_DEPRECATE_WARN_BEGIN
597 std::array<DataType, 3> supportedWeightTypes =
598 {
599 DataType::QAsymmU8,
600 DataType::QSymmS8,
601 DataType::QuantizedSymm8PerAxis // deprecated
602 };
603 ARMNN_NO_DEPRECATE_WARN_END
604
Teresa Charlind8df0262019-11-11 12:28:15 +0000605 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000606 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000607 {
Teresa Charlind8df0262019-11-11 12:28:15 +0000608
609 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
610 "Reference convolution2d: weights type not supported for quantized input.");
611 }
612 else
613 {
614 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
615 "Reference DepthwiseConvolution2d: weights is not a supported type.");
616
617 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
618 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
619 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100620
621 if (biases.has_value())
622 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000623 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100624 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000625 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100626 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100627 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100628 DataType::Signed32
629 };
630 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
631 "Reference DepthwiseConvolution2d: biases is not a supported type.");
632 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000633 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100634
635 return supported;
636
arovir011c7c81b2018-10-08 11:34:28 +0100637}
638
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000639bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
640 const TensorInfo& output,
641 Optional<std::string&> reasonIfUnsupported) const
642{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100643 bool supported = true;
644
Ryan OShea9add1202020-02-07 10:06:33 +0000645 std::array<DataType,4> supportedInputTypes = {
646 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000647 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000648 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000649 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100650 };
651
652 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000653 "Reference for Dequantize layer: input type not supported.");
654
655 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
656 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100657
Derek Lambertid466a542020-01-22 15:37:29 +0000658 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
659 "Reference dequantize: per-axis quantized input not support .");
660
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000661 std::array<DataType,3> supportedOutputTypes = {
662 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +0000663 DataType::Float32,
664 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100665 };
666
667 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000668 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100669
670 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000671 "Reference for Dequantize layer: input/output shapes have different num total "
672 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100673
674 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000675}
676
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000677bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
678 const TensorInfo& scores,
679 const TensorInfo& anchors,
680 const TensorInfo& detectionBoxes,
681 const TensorInfo& detectionClasses,
682 const TensorInfo& detectionScores,
683 const TensorInfo& numDetections,
684 const DetectionPostProcessDescriptor& descriptor,
685 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000686{
Jan Eilers8eb25602020-03-09 12:13:48 +0000687 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +0000688
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100689 bool supported = true;
690
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000691 std::array<DataType,4> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100692 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000693 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100694 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000695 DataType::QAsymmU8,
696 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100697 };
698
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000699 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100700 "Reference DetectionPostProcess: input 0 is not a supported type.");
701
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000702 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100703 "Reference DetectionPostProcess: input 1 is not a supported type.");
704
705 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000706}
707
Pablo Tellof0bd6832019-04-26 17:58:13 +0100708bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
709 const TensorInfo& output,
710 const DepthwiseConvolution2dDescriptor& descriptor,
711 const TensorInfo& weights,
712 const Optional<TensorInfo>& biases,
713 Optional<std::string&> reasonIfUnsupported) const
714{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100715 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100716}
717
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100718bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100719 const TensorInfo& input1,
720 const TensorInfo& output,
721 Optional<std::string&> reasonIfUnsupported) const
722{
Sadik Armagan2999a022019-04-09 14:20:12 +0100723 bool supported = true;
724
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000725 std::array<DataType,5> supportedTypes = {
726 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100727 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100728 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000729 DataType::QAsymmU8,
730 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100731 };
732
733 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
734 "Reference division: input 0 is not a supported type.");
735
736 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
737 "Reference division: input 1 is not a supported type.");
738
739 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
740 "Reference division: output is not a supported type.");
741
742 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
743 "Reference division: input 0 and Input 1 types are mismatched");
744
745 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
746 "Reference division: input and output types are mismatched");
747
748 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
749 "Reference division: shapes are not suitable for implicit broadcast.");
750
751 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100752}
753
josh minor4a3c6102020-01-06 16:40:46 -0600754bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
755 const TensorInfo& output,
756 const ElementwiseUnaryDescriptor& descriptor,
757 Optional<std::string&> reasonIfUnsupported) const
758{
Jan Eilers8eb25602020-03-09 12:13:48 +0000759 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -0600760
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000761 std::array<DataType, 5> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -0600762 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000763 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -0600764 DataType::Float32,
765 DataType::Float16,
766 DataType::QAsymmU8,
767 DataType::QSymmS16
768 };
769
770 bool supported = true;
771
772 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
773 "Reference elementwise unary: input type not supported");
774
775 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
776 "Reference elementwise unary: output type not supported");
777
778 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
779 "Reference elementwise unary: input and output types not matching");
780
781 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
782 "Reference elementwise unary: input and output shapes"
783 "have different number of total elements");
784
785 return supported;
786}
787
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000788bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
789 const TensorInfo& input1,
790 const TensorInfo& output,
791 Optional<std::string&> reasonIfUnsupported) const
792{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100793 return IsComparisonSupported(input0,
794 input1,
795 output,
796 ComparisonDescriptor(ComparisonOperation::Equal),
797 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000798}
799
arovir011c7c81b2018-10-08 11:34:28 +0100800bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
801 const FakeQuantizationDescriptor& descriptor,
802 Optional<std::string&> reasonIfUnsupported) const
803{
Jan Eilers8eb25602020-03-09 12:13:48 +0000804 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100805 bool supported = true;
806
807 std::array<DataType,1> supportedTypes =
808 {
809 DataType::Float32
810 };
811
812 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
813 "Reference fake quantization: input type not supported.");
814
815 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100816}
817
818bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
819 const TensorInfo& output,
820 Optional<std::string&> reasonIfUnsupported) const
821{
Jan Eilers8eb25602020-03-09 12:13:48 +0000822 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +0100823 bool supported = true;
824
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000825 std::array<DataType,4> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100826 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000827 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +0100828 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100829 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000830 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100831 };
832
833 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
834 "Reference Floor: input type not supported.");
835
836 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
837 "Reference Floor: output type not supported.");
838
839 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100840}
841
842bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
843 const TensorInfo& output,
844 const TensorInfo& weights,
845 const TensorInfo& biases,
846 const FullyConnectedDescriptor& descriptor,
847 Optional<std::string&> reasonIfUnsupported) const
848{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100849 bool supported = true;
850
851 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000852 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100853 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000854 DataType::BFloat16,
855 DataType::Float32,
856 DataType::Float16,
857 DataType::QAsymmU8,
858 DataType::QAsymmS8,
859 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100860 };
861
862 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863 "Reference Fully Connected: input type not supported.");
864
865 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866 "Reference Fully Connected: output type not supported.");
867
868 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
869 "Reference Fully Connected: input and output types mismatched.");
870
871 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
872 "Reference Fully Connected: weights type not supported.");
873
Francis Murtaghddb1d062020-03-10 13:51:45 +0000874 ARMNN_NO_DEPRECATE_WARN_BEGIN
875 std::array<DataType, 3> supportedWeightTypes =
876 {
877 DataType::QAsymmU8,
878 DataType::QSymmS8,
879 DataType::QuantizedSymm8PerAxis // deprecated
880 };
881 ARMNN_NO_DEPRECATE_WARN_END
882
883 if (IsQuantized8BitType(input.GetDataType()))
884 {
885
886 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
887 "Reference Fully Connected: weights type not supported for quantized input.");
888 }
889 else
890 {
891 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
892 "Reference Fully Connected: weights is not a supported type.");
893
894 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
895 "Reference Fully Connected: input and weights types mismatched.");
896 }
Francis Murtagh46c09d02019-05-28 08:15:28 +0100897
898 if (descriptor.m_BiasEnabled)
899 {
900 // Defined supported types for bias
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000901 std::array<DataType, 4>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100902 supportedBiasTypes =
903 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000904 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100905 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100906 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100907 DataType::Signed32
908 };
909
910 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
911 "Reference Fully Connected: bias type not supported.");
912
913 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
914 "Reference Fully Connected: bias and weight types mismatch.");
915
916 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
917 "Reference Fully Connected: bias type inferred from weights is incompatible.");
918
919 }
920
921 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100922}
923
narpra014951d842019-01-18 16:53:53 +0000924bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
925 const armnn::TensorInfo& input1,
926 const armnn::TensorInfo& output,
927 armnn::Optional<std::string&> reasonIfUnsupported) const
928{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100929 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000930 std::array<DataType,5> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100931 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000932 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100933 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100934 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000935 DataType::QAsymmU8,
936 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100937 };
938
939 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
940 "Reference Gather: input type not supported");
941
942 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
943 "Reference Gather: output type not supported");
944
945 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
946 "Reference Gather: indices (input1) type not supported");
947
948 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
949 "Reference Gather: input and output types not matching");
950
951 return supported;
narpra014951d842019-01-18 16:53:53 +0000952}
953
FrancisMurtagh878f0232018-12-19 10:56:15 +0000954bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
955 const TensorInfo& input1,
956 const TensorInfo& output,
957 Optional<std::string&> reasonIfUnsupported) const
958{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100959 return IsComparisonSupported(input0,
960 input1,
961 output,
962 ComparisonDescriptor(ComparisonOperation::Greater),
963 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +0000964}
965
Derek Lamberti901ea112019-12-10 22:07:09 +0000966bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
967 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +0100968{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100969 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100970}
971
Kevin May09ca49c2019-10-09 12:37:34 +0100972bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
973 const TensorInfo& output,
974 const InstanceNormalizationDescriptor& descriptor,
975 Optional<std::string&> reasonIfUnsupported) const
976{
Jan Eilers8eb25602020-03-09 12:13:48 +0000977 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +0100978 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000979 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +0100980 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000981 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +0100982 DataType::Float32,
983 DataType::Float16
984 };
985
986 bool supported = true;
987
988 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
989 "Reference Instance Normalization: input type not supported.");
990
991 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
992 "Reference Instance Normalization: output type not supported.");
993
994 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
995 "Reference Instance Normalization: input and output types mismatched.");
996
997 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
998 "Reference Instance Normalization: input and output shapes have different "
999 "num total elements.");
1000
1001 return supported;
1002}
1003
arovir011c7c81b2018-10-08 11:34:28 +01001004bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1005 const TensorInfo& output,
1006 const L2NormalizationDescriptor& descriptor,
1007 Optional<std::string&> reasonIfUnsupported) const
1008{
Jan Eilers8eb25602020-03-09 12:13:48 +00001009 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001010 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001011 std::array<DataType, 5> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001012 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001013 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001014 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001015 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001016 DataType::QAsymmU8,
1017 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001018 };
1019
1020 bool supported = true;
1021
1022 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1023 "Reference L2normalization: input type not supported.");
1024
1025 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1026 "Reference L2normalization: output type not supported.");
1027
1028 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1029 "Reference L2normalization: input and output types mismatched.");
1030
1031 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1032 "Reference L2normalization: input and output shapes have different "
1033 "num total elements.");
1034
1035 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001036}
1037
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001038bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1039 const TensorInfo& output,
1040 const LogSoftmaxDescriptor& descriptor,
1041 Optional<std::string&> reasonIfUnsupported) const
1042{
Jan Eilers8eb25602020-03-09 12:13:48 +00001043 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001044
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001045 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001046 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001047 DataType::BFloat16,
1048 DataType::Float32,
1049 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001050 };
1051
1052 bool supported = true;
1053 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1054 "Reference LogSoftmax: input type not supported");
1055
1056 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1057 "Reference LogSoftmax: output type not supported");
1058
1059 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1060 "Reference LogSoftmax: input and output types do not match");
1061
1062 return supported;
1063}
1064
arovir011c7c81b2018-10-08 11:34:28 +01001065bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1066 const TensorInfo& outputStateIn,
1067 const TensorInfo& cellStateIn,
1068 const TensorInfo& scratchBuffer,
1069 const TensorInfo& outputStateOut,
1070 const TensorInfo& cellStateOut,
1071 const TensorInfo& output,
1072 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001073 const LstmInputParamsInfo& paramsInfo,
1074 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001075{
Jan Eilers8eb25602020-03-09 12:13:48 +00001076 IgnoreUnused(descriptor);
1077 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001078
1079 bool supported = true;
1080
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001081 std::array<DataType,3> supportedTypes = {
1082 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001083 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001084 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001085 };
1086
Jan Eilersd01a83c2019-07-03 18:20:40 +01001087 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001088 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1089 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001090 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1091 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001092 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1093 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001094 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1095 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001096 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1097 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001098 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1099 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001100 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1101 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001102 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001103 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001104 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001105 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001106 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001107 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001108 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001109 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001110 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001111 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001112 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001113 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001114 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001115 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001116 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001117 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001118 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001119 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001120 "Reference Lstm: input and OutputGateBias types are mismatched");
1121 if (!descriptor.m_CifgEnabled)
1122 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001123 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001124 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001125 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001126 reasonIfUnsupported,
1127 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001128 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001129 "Reference Lstm: input and InputGateBias types are mismatched");
1130 if (descriptor.m_PeepholeEnabled)
1131 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001132 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001133 reasonIfUnsupported,
1134 "Reference Lstm: input and CellToInputWeights types are mismatched");
1135 }
1136 }
1137 if (descriptor.m_PeepholeEnabled)
1138 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001139 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001140 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001141 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001142 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1143 }
1144 if (descriptor.m_ProjectionEnabled)
1145 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001146 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001147 "Reference Lstm: input and mProjectionWeights types are mismatched");
1148 if (paramsInfo.m_ProjectionBias != nullptr)
1149 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001150 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001151 "Reference Lstm: input and ProjectionBias types are mismatched");
1152 }
1153 }
1154 if (descriptor.m_LayerNormEnabled)
1155 {
1156 if (!descriptor.m_CifgEnabled)
1157 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001158 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001159 reasonIfUnsupported,
1160 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1161 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001162 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001163 reasonIfUnsupported,
1164 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001165 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001166 reasonIfUnsupported,
1167 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001168 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001169 reasonIfUnsupported,
1170 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1171 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001172
1173 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001174}
1175
saoste012df12b32018-11-28 16:57:20 +00001176bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1177 const TensorInfo& input1,
1178 const TensorInfo& output,
1179 Optional<std::string&> reasonIfUnsupported) const
1180{
Sadik Armagan2999a022019-04-09 14:20:12 +01001181 bool supported = true;
1182
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001183 std::array<DataType,6> supportedTypes = {
1184 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001185 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001186 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001187 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001188 DataType::QAsymmU8,
1189 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001190 };
1191
1192 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1193 "Reference maximum: input 0 is not a supported type.");
1194
1195 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1196 "Reference maximum: input 1 is not a supported type.");
1197
1198 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1199 "Reference maximum: output is not a supported type.");
1200
1201 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1202 "Reference maximum: input 0 and Input 1 types are mismatched");
1203
1204 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1205 "Reference maximum: input and output types are mismatched");
1206
1207 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1208 "Reference maximum: shapes are not suitable for implicit broadcast.");
1209
1210 return supported;
saoste012df12b32018-11-28 16:57:20 +00001211}
1212
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001213bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1214 const TensorInfo& output,
1215 const MeanDescriptor& descriptor,
1216 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001217{
James Conroy4d1ff582019-06-10 17:06:39 +01001218 bool supported = true;
1219 std::string meanLayerStr = "Mean";
1220 std::string outputTensorStr = "output";
1221
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001222 std::array<DataType,5> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001223 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001224 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001225 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001226 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001227 DataType::QAsymmU8,
1228 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001229 };
1230
1231 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1232 "Reference Mean: input type not supported.");
1233
1234 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1235 "Reference Mean: input and output types are mismatched");
1236
1237 if (descriptor.m_KeepDims)
1238 {
1239 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1240 reasonIfUnsupported,
1241 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1242 output.GetNumDimensions(),
1243 meanLayerStr, outputTensorStr).data());
1244 }
1245 else if (descriptor.m_Axis.empty())
1246 {
1247 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1248 reasonIfUnsupported,
1249 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1250 meanLayerStr, outputTensorStr).data());
1251 }
1252 else
1253 {
1254 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1255
1256 if (outputDim > 0)
1257 {
1258 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1259 reasonIfUnsupported,
1260 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1261 meanLayerStr, outputTensorStr).data());
1262 }
1263 else
1264 {
1265 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1266 reasonIfUnsupported,
1267 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1268 meanLayerStr, outputTensorStr).data());
1269 }
1270 }
1271
1272 return supported;
narpra0132b90462018-09-13 11:07:48 +01001273}
1274
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001275bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001276 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001277 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001278 Optional<std::string&> reasonIfUnsupported) const
1279{
Jim Flynne242f2d2019-05-22 14:24:13 +01001280 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001281}
1282
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001283bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1284 const TensorInfo &output,
1285 Optional<std::string &> reasonIfUnsupported) const
1286{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001287 bool supported = true;
1288
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001289 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001290 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001291 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001292 DataType::Float32,
1293 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001294 DataType::QAsymmU8,
1295 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001296 DataType::Boolean
1297 };
1298
1299 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1300 "Reference MemCopy: input type not supported");
1301
1302 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1303 "Reference MemCopy: output type not supported");
1304
1305 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1306 "Reference MemCopy: input and output types are mismatched");
1307
1308 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001309}
1310
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001311bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1312 const TensorInfo& input1,
1313 const TensorInfo& output,
1314 Optional<std::string&> reasonIfUnsupported) const
1315{
Sadik Armagan2999a022019-04-09 14:20:12 +01001316 bool supported = true;
1317
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001318 std::array<DataType,5> supportedTypes = {
1319 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001320 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001321 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001322 DataType::QAsymmU8,
1323 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001324 };
1325
1326 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1327 "Reference minimum: input 0 is not a supported type.");
1328
1329 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1330 "Reference minimum: input 1 is not a supported type.");
1331
1332 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1333 "Reference minimum: output is not a supported type.");
1334
1335 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1336 "Reference minimum: input 0 and Input 1 types are mismatched");
1337
1338 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1339 "Reference minimum: input and output types are mismatched");
1340
1341 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1342 "Reference minimum: shapes are not suitable for implicit broadcast.");
1343
1344 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001345}
1346
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001347bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1348 const TensorInfo& input1,
1349 const TensorInfo& output,
1350 Optional<std::string&> reasonIfUnsupported) const
1351{
Sadik Armagan2999a022019-04-09 14:20:12 +01001352 bool supported = true;
1353
Keith Davis67e6c542020-02-19 10:08:33 +00001354 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001355 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001356 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001357 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001358 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001359 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001360 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001361 };
1362
1363 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1364 "Reference multiplication: input 0 is not a supported type.");
1365
1366 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1367 "Reference multiplication: input 1 is not a supported type.");
1368
1369 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1370 "Reference multiplication: output is not a supported type.");
1371
1372 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1373 "Reference multiplication: input 0 and Input 1 types are mismatched");
1374
1375 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1376 "Reference multiplication: input and output types are mismatched");
1377
1378 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1379 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1380
1381 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001382}
1383
1384bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1385 const TensorInfo& output,
1386 const NormalizationDescriptor& descriptor,
1387 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001388{
Jan Eilers8eb25602020-03-09 12:13:48 +00001389 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001390
1391 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001392 std::array<DataType, 5> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001393 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001394 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001395 DataType::Float16,
1396 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001397 DataType::QAsymmU8,
1398 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001399 };
1400
1401 bool supported = true;
1402
1403 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1404 "Reference normalization: input type not supported.");
1405
1406 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1407 "Reference normalization: output type not supported.");
1408
1409 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1410 "Reference normalization: input and output shapes have different "
1411 "num total elements.");
1412
1413 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001414}
1415
Derek Lamberti901ea112019-12-10 22:07:09 +00001416bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1417 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001418{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001419 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001420}
1421
1422bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1423 const TensorInfo& output,
1424 const PadDescriptor& descriptor,
1425 Optional<std::string&> reasonIfUnsupported) const
1426{
Jan Eilers8eb25602020-03-09 12:13:48 +00001427 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001428 bool supported = true;
1429
1430 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001431 std::array<DataType,5> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001432 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001433 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001434 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001435 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001436 DataType::QAsymmU8,
1437 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001438 };
1439
1440 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1441 "Reference pad: input is not a supported type.");
1442
1443 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1444 "Reference pad: output is not a supported type.");
1445
1446 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1447 "Reference pad: input and output types are mismatched.");
1448
1449 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001450}
1451
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001452bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1453 const TensorInfo& output,
1454 const PermuteDescriptor& descriptor,
1455 Optional<std::string&> reasonIfUnsupported) const
1456{
Jan Eilers8eb25602020-03-09 12:13:48 +00001457 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001458 bool supported = true;
1459
1460 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001461 std::array<DataType, 5> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001462 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001463 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001464 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001465 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001466 DataType::QAsymmU8,
1467 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001468 };
1469
1470 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1471 "Reference permute: input is not a supported type.");
1472
1473 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1474 "Reference permute: output is not a supported type.");
1475
1476 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1477 "Reference permute: input and output types are mismatched.");
1478
1479 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001480}
1481
1482bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1483 const TensorInfo& output,
1484 const Pooling2dDescriptor& descriptor,
1485 Optional<std::string&> reasonIfUnsupported) const
1486{
Jan Eilers8eb25602020-03-09 12:13:48 +00001487 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001488 bool supported = true;
1489
1490 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001491 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001492 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001493 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001494 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001495 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001496 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001497 DataType::QAsymmU8,
1498 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001499 };
1500
1501 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1502 "Reference poolind2d: input is not a supported type.");
1503
1504 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1505 "Reference poolind2d: output is not a supported type.");
1506
1507 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1508 "Reference poolind2d: input and output types are mismatched.");
1509
1510 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001511}
1512
Derek Lamberti5f400d62019-03-25 15:41:58 +00001513bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1514 const TensorInfo& output,
1515 Optional<std::string&> reasonIfUnsupported) const
1516{
1517 bool supported = true;
1518
Finn Williamsfd271062019-12-04 14:27:27 +00001519 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001520 std::array<DataType,7> supportedInputTypes = {
1521 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001522 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001523 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001524 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001525 DataType::QAsymmU8,
1526 DataType::QSymmS8,
1527 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001528 };
1529
1530 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1531 "Reference quantize: input type not supported.");
1532
1533 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001534 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001535 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001536 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001537 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001538 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001539 };
1540 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1541 "Reference quantize: output type not supported.");
1542
1543 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1544 "Reference quantize: input and output shapes have different num total elements.");
1545
1546 return supported;
1547}
1548
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001549bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001550 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001551 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001552 Optional<std::string&> reasonIfUnsupported) const
1553{
Jan Eilers8eb25602020-03-09 12:13:48 +00001554 IgnoreUnused(output);
1555 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001556 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001557 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001558 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001559 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001560 DataType::Float32,
1561 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001562 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001563 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001564 DataType::QAsymmU8,
1565 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001566 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001567
Nina Drozd2f2778f2019-05-27 10:37:05 +01001568 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1569 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001570}
1571
1572bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001573 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001574 Optional<std::string&> reasonIfUnsupported) const
1575{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001576 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001577 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001578 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001579 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001580 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001581 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001582 DataType::QAsymmU8,
1583 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001584 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001585
1586 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1587 "Reference ResizeBilinear: input type not supported");
1588
1589 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1590 "Reference ResizeBilinear: output type not supported");
1591
1592 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1593 "Reference ResizeBilinear: input and output types not matching");
1594
1595 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001596}
1597
Teresa Charlin970f43b2019-07-01 13:51:07 +01001598bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1599 const TensorInfo& output,
1600 const ResizeDescriptor& descriptor,
1601 Optional<std::string&> reasonIfUnsupported) const
1602{
Jan Eilers8eb25602020-03-09 12:13:48 +00001603 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001604 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001605 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001606 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001607 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001608 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001609 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001610 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001611 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001612 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001613 };
1614
1615 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1616 "Reference Resize: input type not supported");
1617
1618 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1619 "Reference Resize: output type not supported");
1620
1621 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1622 "Reference Resize: input and output types not matching");
1623
1624 return supported;
1625}
1626
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001627bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1628 const TensorInfo& output,
1629 Optional<std::string&> reasonIfUnsupported) const
1630{
josh minor4a3c6102020-01-06 16:40:46 -06001631 return IsElementwiseUnarySupported(input,
1632 output,
1633 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1634 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001635}
1636
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001637bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1638 const TensorInfo& output,
1639 const SliceDescriptor& descriptor,
1640 Optional<std::string&> reasonIfUnsupported) const
1641{
Jan Eilers8eb25602020-03-09 12:13:48 +00001642 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001643 bool supported = true;
1644
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001645 std::array<DataType, 4> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001647 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001648 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001649 DataType::QAsymmU8,
1650 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001651 };
1652
1653 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1654 "Reference Slice: input type not supported");
1655
1656 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1657 "Reference Slice: output type not supported");
1658
1659 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1660 "Reference Slice: input and output types are mismatched");
1661
1662 return supported;
1663}
1664
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001665bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1666 const TensorInfo& output,
1667 const SoftmaxDescriptor& descriptor,
1668 Optional<std::string&> reasonIfUnsupported) const
1669{
Jan Eilers8eb25602020-03-09 12:13:48 +00001670 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001671 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001672 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001673 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001674 DataType::BFloat16,
1675 DataType::Float32,
1676 DataType::Float16,
1677 DataType::QSymmS8,
1678 DataType::QAsymmS8,
1679 DataType::QAsymmU8,
1680 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001681 };
1682
1683 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001684 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001685
1686 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001687 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001688
1689 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001690 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001691
1692 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001693}
1694
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001695bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1696 const TensorInfo& output,
1697 const SpaceToBatchNdDescriptor& descriptor,
1698 Optional<std::string&> reasonIfUnsupported) const
1699{
Jan Eilers8eb25602020-03-09 12:13:48 +00001700 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001701 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001702 std::array<DataType,5> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001703 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001704 DataType::BFloat16,
1705 DataType::Float32,
1706 DataType::Float16,
1707 DataType::QAsymmU8,
1708 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001709 };
1710
1711 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1712 "Reference SpaceToBatchNd: input type not supported");
1713
1714 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1715 "Reference SpaceToBatchNd: output type not supported");
1716
1717 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1718 "Reference SpaceToBatchNd: input and output types are mismatched");
1719
1720 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001721}
1722
Keith Davisa57eccb2019-06-14 17:33:22 +01001723bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001724 const TensorInfo& output,
1725 const SpaceToDepthDescriptor& descriptor,
1726 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001727{
1728
Jan Eilers8eb25602020-03-09 12:13:48 +00001729 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001730 bool supported = true;
1731
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001732 std::array<DataType,5> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001733 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001734 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001735 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001736 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001737 DataType::QAsymmU8,
1738 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001739 };
1740
1741 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1742 "Reference SpaceToDepth: input type not supported");
1743
1744 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1745 "Reference SpaceToDepth: output type not supported");
1746
1747 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1748 "Reference SpaceToDepth: input and output types are mismatched");
1749
1750 return supported;
1751}
1752
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001753bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1754 const ViewsDescriptor& descriptor,
1755 Optional<std::string&> reasonIfUnsupported) const
1756{
Jan Eilers8eb25602020-03-09 12:13:48 +00001757 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001758 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001759 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001760 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001761 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001762 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001763 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001764 DataType::QAsymmU8,
1765 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001766 };
1767
1768 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1769 "Reference splitter: input type not supported");
1770
1771 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001772}
1773
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001774bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1775 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1776 const ViewsDescriptor& descriptor,
1777 Optional<std::string&> reasonIfUnsupported) const
1778{
Jan Eilers8eb25602020-03-09 12:13:48 +00001779 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001780 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001781 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001782 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001783 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001784 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001785 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001786 DataType::QAsymmU8,
1787 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001788 };
1789
1790 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1791 "Reference splitter: output type not supported");
1792 for (const TensorInfo output : outputs)
1793 {
1794 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1795 "Reference splitter: input type not supported");
1796
1797 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1798 "Reference splitter: input and output types mismatched.");
1799 }
1800
1801 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001802}
1803
Matthew Jackson81e601c2019-07-11 12:07:09 +01001804bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1805 const TensorInfo& output,
1806 const StackDescriptor& descriptor,
1807 Optional<std::string&> reasonIfUnsupported) const
1808{
Jan Eilers8eb25602020-03-09 12:13:48 +00001809 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001810
1811 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001812 std::array<DataType,5> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001813 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001814 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001815 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001816 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001817 DataType::QAsymmU8,
1818 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001819 };
1820
1821 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1822 "Reference stack: output type not supported");
1823 for (const TensorInfo* input : inputs)
1824 {
1825 BOOST_ASSERT(input != nullptr);
1826 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1827 "Reference stack: input type not supported");
1828
1829 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1830 "Reference stack: input and output types mismatched.");
1831 }
1832
1833 return supported;
1834}
1835
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001836bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1837 const TensorInfo& output,
1838 const StridedSliceDescriptor& descriptor,
1839 Optional<std::string&> reasonIfUnsupported) const
1840{
Jan Eilers8eb25602020-03-09 12:13:48 +00001841 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001842 bool supported = true;
1843
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001844 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001845 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001846 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001847 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001848 DataType::QAsymmU8,
1849 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001850 };
1851
1852 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1853 "Reference StridedSlice: input type not supported");
1854
1855 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1856 "Reference StridedSlice: output type not supported");
1857
1858 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1859 "Reference StridedSlice: input and output types are mismatched");
1860
1861 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001862}
1863
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001864bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1865 const TensorInfo& input1,
1866 const TensorInfo& output,
1867 Optional<std::string&> reasonIfUnsupported) const
1868{
Sadik Armagan2999a022019-04-09 14:20:12 +01001869 bool supported = true;
1870
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001871 std::array<DataType,5> supportedTypes = {
1872 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001873 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001874 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001875 DataType::QAsymmU8,
1876 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001877 };
1878
1879 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1880 "Reference subtraction: input 0 is not a supported type.");
1881
1882 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1883 "Reference subtraction: input 1 is not a supported type.");
1884
1885 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1886 "Reference subtraction: output is not a supported type.");
1887
1888 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1889 "Reference subtraction: input 0 and Input 1 types are mismatched");
1890
1891 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1892 "Reference subtraction: input and output types are mismatched");
1893
1894 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1895 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1896
1897 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001898}
1899
Matteo Martincighab9e5252019-06-13 17:27:46 +01001900bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1901 const TensorInfo& alpha,
1902 const TensorInfo& output,
1903 Optional<std::string&> reasonIfUnsupported) const
1904{
1905 bool supported = true;
1906
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001907 std::array<DataType, 5> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001908 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001909 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001910 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001911 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001912 DataType::QAsymmU8,
1913 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001914 };
1915
1916 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1917 "PReLU: input is not a supported type.");
1918
1919 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1920 "PReLU: alpha is not a supported type.");
1921
1922 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1923 "PReLU: output is not a supported type.");
1924
1925 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1926 "PReLU: input, alpha and output types are mismatched");
1927
1928 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1929 "PReLU: shapes are not suitable for implicit broadcast");
1930
1931 return supported;
1932}
1933
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001934bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1935 const TensorInfo& output,
1936 const TransposeConvolution2dDescriptor& descriptor,
1937 const TensorInfo& weights,
1938 const Optional<TensorInfo>& biases,
1939 Optional<std::string&> reasonIfUnsupported) const
1940{
Jan Eilers8eb25602020-03-09 12:13:48 +00001941 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001942 bool supported = true;
1943
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001944 std::array<DataType,5> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001945 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001946 DataType::BFloat16,
1947 DataType::Float32,
1948 DataType::Float16,
1949 DataType::QAsymmU8,
1950 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001951 };
1952
1953 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1954 "Reference TransposeConvolution2d: input is not a supported type.");
1955
1956 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1957 "Reference TransposeConvolution2d: output is not a supported type.");
1958
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001959 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1960 "Reference TransposeConvolution2d: input and output types mismatched.");
1961
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001962
1963 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001964 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001965 {
Derek Lambertid466a542020-01-22 15:37:29 +00001966 ARMNN_NO_DEPRECATE_WARN_BEGIN
1967 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001968 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001969 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001970 DataType::QSymmS8,
1971 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001972 };
Derek Lambertid466a542020-01-22 15:37:29 +00001973 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001974
1975 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1976 "Reference TransposeConvolution2d: weights type not supported for "
1977 "quantized input.");
1978 }
1979 else
1980 {
1981 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1982 "Reference TransposeConvolution2d: weights is not a supported type.");
1983
1984 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1985 "Reference TransposeConvolution2d: input and weights types mismatched.");
1986 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001987
1988 if (biases.has_value())
1989 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001990 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001991 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001992 DataType::BFloat16,
1993 DataType::Float32,
1994 DataType::Float16,
1995 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001996 };
1997 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1998 "Reference TransposeConvolution2d: biases is not a supported type.");
1999 }
2000
2001 return supported;
2002}
2003
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002004bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2005 const TensorInfo& output,
2006 const TransposeDescriptor& descriptor,
2007 Optional<std::string&> reasonIfUnsupported) const
2008{
Jan Eilers8eb25602020-03-09 12:13:48 +00002009 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002010 bool supported = true;
2011
2012 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002013 std::array<DataType, 5> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002014 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002015 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002016 DataType::Float32,
2017 DataType::Float16,
2018 DataType::QAsymmU8,
2019 DataType::QSymmS16
2020 };
2021
2022 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2023 "Reference transpose: input is not a supported type.");
2024
2025 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2026 "Reference transpose: output is not a supported type.");
2027
2028 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2029 "Reference transpose: input and output types are mismatched.");
2030
2031 return supported;
2032}
2033
arovir011c7c81b2018-10-08 11:34:28 +01002034} // namespace armnn