blob: 3bcb7e0261fbe4709857cf35dda9b479e8ae5cb6 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010014#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000015
Matteo Martincighe011d202019-11-28 11:35:47 +000016#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <array>
20
telsoa014fcda012018-03-09 14:13:49 +000021using namespace boost;
22
23namespace armnn
24{
25
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010026namespace
27{
28
29template<typename Float32Func, typename Uint8Func, typename ... Params>
30bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
31 DataType dataType,
32 Float32Func floatFuncPtr,
33 Uint8Func uint8FuncPtr,
34 Params&&... params)
35{
36 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
37 dataType,
38 &FalseFunc<Params...>,
39 floatFuncPtr,
40 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000041 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000042 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010043 std::forward<Params>(params)...);
44}
45
46} // anonymous namespace
47
namespace
{

/// Builds the reason string reported when a tensor has an unexpected number of dimensions.
///
/// The string arguments are only read, so they are taken by const reference. This lets callers pass
/// const strings and temporaries; the previous non-const references rejected both.
///
/// @param expected    Number of dimensions the layer requires.
/// @param actual      Number of dimensions the tensor actually has.
/// @param layerStr    Layer name inserted after the "Reference " prefix.
/// @param tensorName  Name of the offending tensor, quoted in the message.
/// @return A message of the form
///         "Reference <layer>: Expected <expected> dimensions but got <actual> dimensions instead,
///          for the '<tensor>' tensor."
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got " +
           std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000063
Sadik Armagan9199e582019-09-05 17:35:31 +010064bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
65 Optional<std::string&> reasonIfUnsupported) const
66{
josh minor4a3c6102020-01-06 16:40:46 -060067 return IsElementwiseUnarySupported(input,
68 output,
69 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
70 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010071}
72
arovir011c7c81b2018-10-08 11:34:28 +010073bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
74 const TensorInfo& output,
75 const ActivationDescriptor& descriptor,
76 Optional<std::string&> reasonIfUnsupported) const
77{
Derek Lamberti50db4e82019-03-13 14:16:15 +000078 bool supported = true;
79
80 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000081 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000082 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +000083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010084 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000085 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000086 DataType::QAsymmU8,
87 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000088 };
89
90 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
91 "Reference activation: input type not supported.");
92
93 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
94 "Reference activation: output type not supported.");
95
96 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
97 "Reference activation: input and output types mismatched.");
98
99 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
100 "Reference activation: input and output shapes are of different rank.");
101
102
103 struct ActivationFunctionSupported : public Rule
104 {
105 ActivationFunctionSupported(const ActivationDescriptor& desc)
106 {
107 switch(desc.m_Function)
108 {
109 case ActivationFunction::Abs:
110 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000111 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000112 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000113 case ActivationFunction::LeakyReLu:
114 case ActivationFunction::Linear:
115 case ActivationFunction::ReLu:
116 case ActivationFunction::Sigmoid:
117 case ActivationFunction::SoftReLu:
118 case ActivationFunction::Sqrt:
119 case ActivationFunction::Square:
120 case ActivationFunction::TanH:
121 {
122 m_Res = true;
123 break;
124 }
125 default:
126 {
127 m_Res = false;
128 break;
129 }
130 }
131 }
132 };
133
134 // Function is supported
135 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
136 "Reference activation: function not supported.");
137
138 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100139}
140
141bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
142 const TensorInfo& input1,
143 const TensorInfo& output,
144 Optional<std::string&> reasonIfUnsupported) const
145{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000146 bool supported = true;
147
Keith Davis0c2eeac2020-02-11 16:51:50 +0000148 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000149 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000150 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100151 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000152 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000153 DataType::QAsymmU8,
154 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000155 };
156
157 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
158 "Reference addition: input 0 is not a supported type.");
159
160 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
161 "Reference addition: input 1 is not a supported type.");
162
163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
164 "Reference addition: output is not a supported type.");
165
166 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
167 "Reference addition: input 0 and Input 1 types are mismatched");
168
169 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
170 "Reference addition: input and output types are mismatched");
171
172 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
173 "Reference addition: shapes are not suitable for implicit broadcast.");
174
175 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100176}
177
Nikhil Raj68c2c902019-09-19 11:21:11 +0100178bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
179 const armnn::ArgMinMaxDescriptor &descriptor,
180 armnn::Optional<std::string &> reasonIfUnsupported) const
181{
Jan Eilers8eb25602020-03-09 12:13:48 +0000182 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100183
Sadik Armagan303980c2020-04-17 12:45:14 +0100184 std::array<DataType, 6> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100185 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000186 DataType::BFloat16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100187 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100188 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000189 DataType::QAsymmU8,
190 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000191 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100192 };
193
194 bool supported = true;
195
196 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
197 "Reference ArgMinMax: input is not a supported type.");
198 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
199 "Reference ArgMinMax: output type not supported");
200
201 return supported;
202}
203
arovir011c7c81b2018-10-08 11:34:28 +0100204bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
205 const TensorInfo& output,
206 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100207 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100208 const TensorInfo& beta,
209 const TensorInfo& gamma,
210 const BatchNormalizationDescriptor& descriptor,
211 Optional<std::string&> reasonIfUnsupported) const
212{
Jan Eilers8eb25602020-03-09 12:13:48 +0000213 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100214
Sadik Armagan303980c2020-04-17 12:45:14 +0100215 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100216 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000217 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100218 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100219 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100220 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000221 DataType::QAsymmU8,
222 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100223 };
224
225 bool supported = true;
226
227 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
228 "Reference batch normalization: input is not a supported type.");
229
230 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
231 "Reference batch normalization: output is not a supported type.");
232
233 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
234 "Reference batch normalization: input and output types are mismatched");
235
236 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
237 "Reference batch normalization: mean is not a supported type.");
238
239 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
240 "Reference batch normalization: variance is not a supported type.");
241
242 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
243 "Reference batch normalization: beta is not a supported type.");
244
245 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
246 "Reference batch normalization: gamma is not a supported type.");
247
248 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100249}
250
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000251bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
252 const TensorInfo& output,
253 const BatchToSpaceNdDescriptor& descriptor,
254 Optional<std::string&> reasonIfUnsupported) const
255{
Jan Eilers8eb25602020-03-09 12:13:48 +0000256 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100257
258 bool supported = true;
259
260 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
261 std::string inputTensorStr = "input";
262 std::string outputTensorStr = "output";
263
264 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100265 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100266 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000267 DataType::BFloat16,
268 DataType::Float32,
269 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100270 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000271 DataType::QAsymmU8,
272 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100273 };
274
275 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
276 "Reference BatchToSpaceNd: input type not supported.");
277
278 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
279 "Reference BatchToSpaceNd: output type not supported.");
280
281 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
282 "Reference BatchToSpaceNd: input and output types mismatched.");
283
284 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
285 reasonIfUnsupported,
286 CreateIncorrectDimensionsErrorMsg(4,
287 output.GetNumDimensions(),
288 batchToSpaceNdLayerStr,
289 outputTensorStr).data());
290
291 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
292 reasonIfUnsupported,
293 CreateIncorrectDimensionsErrorMsg(4,
294 input.GetNumDimensions(),
295 batchToSpaceNdLayerStr,
296 inputTensorStr).data());
297
298 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000299}
300
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100301bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
302 const TensorInfo& input1,
303 const TensorInfo& output,
304 const ComparisonDescriptor& descriptor,
305 Optional<std::string&> reasonIfUnsupported) const
306{
Jan Eilers8eb25602020-03-09 12:13:48 +0000307 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100308 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100309 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000310 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000311 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100312 DataType::Float32,
313 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100314 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000315 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000316 DataType::QSymmS16,
317 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100318 };
319
320 bool supported = true;
321 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
322 "Reference comparison: input 0 is not a supported type");
323
324 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
325 "Reference comparison: input 0 and Input 1 types are mismatched");
326
327 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
328 "Reference comparison: output is not of type Boolean");
329
330 return supported;
331}
332
Jim Flynn906f9462019-05-10 13:55:21 +0100333bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
334 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100335 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100336 Optional<std::string&> reasonIfUnsupported) const
337{
Jan Eilers8eb25602020-03-09 12:13:48 +0000338 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100339
340 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000341 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100342 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000343 DataType::BFloat16,
344 DataType::Float32,
345 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000346 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100347 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000348 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100349 };
350
351 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
352 "Reference concatenation: output type not supported");
353 for (const TensorInfo* input : inputs)
354 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100355 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100356 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
357 "Reference concatenation: input type not supported");
358
359 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
360 "Reference concatenation: input and output types mismatched.");
361 }
362
363 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100364}
365
arovir011c7c81b2018-10-08 11:34:28 +0100366bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
367 Optional<std::string&> reasonIfUnsupported) const
368{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100369 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100370 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000371 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100372 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100373 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000374 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100375 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000376 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100377 DataType::QSymmS16,
378 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100379 };
380
381 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
382 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100383}
384
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000385bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
386 const TensorInfo& output,
387 Optional<std::string&> reasonIfUnsupported) const
388{
389 bool supported = true;
390
391 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
392 "Reference for ConvertBf16ToFp32 layer: input type not supported");
393
394 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
395 "Reference for ConvertBf16ToFp32 layer: output type not supported");
396
397 return supported;
398}
399
arovir011c7c81b2018-10-08 11:34:28 +0100400bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
401 const TensorInfo& output,
402 Optional<std::string&> reasonIfUnsupported) const
403{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100404 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
405 input.GetDataType(),
406 &TrueFunc<>,
407 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000408 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000409 &FalseFuncI32<>,
410 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100411 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
412 output.GetDataType(),
413 &FalseOutputFuncF16<>,
414 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000415 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000416 &FalseFuncI32<>,
417 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100418}
419
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000420bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
421 const TensorInfo& output,
422 Optional<std::string&> reasonIfUnsupported) const
423{
424 bool supported = true;
425
426 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
427 "Reference for ConvertFp32ToBf16 layer: input type not supported");
428
429 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
430 "Reference for ConvertFp32ToBf16 layer: output type not supported");
431
432 return supported;
433}
434
arovir011c7c81b2018-10-08 11:34:28 +0100435bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
436 const TensorInfo& output,
437 Optional<std::string&> reasonIfUnsupported) const
438{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100439 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
440 input.GetDataType(),
441 &FalseInputFuncF16<>,
442 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000443 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000444 &FalseFuncI32<>,
445 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100446 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
447 output.GetDataType(),
448 &TrueFunc<>,
449 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000450 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000451 &FalseFuncI32<>,
452 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100453}
454
455bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
456 const TensorInfo& output,
457 const Convolution2dDescriptor& descriptor,
458 const TensorInfo& weights,
459 const Optional<TensorInfo>& biases,
460 Optional<std::string&> reasonIfUnsupported) const
461{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100462 bool supported = true;
463
464 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000465 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000466 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000467 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000468 DataType::Float32,
469 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000470 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100471 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000472 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000473 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100474 };
475
476 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000477 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100478
479 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000480 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100481
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000482 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
483 if (input.GetDataType() == DataType::BFloat16)
484 {
485 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
486 {
487 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
488 supported = false;
489 }
490 }
491 else
492 {
493 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000494 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000495 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100496
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000497 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000498 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000499 {
Derek Lambertid466a542020-01-22 15:37:29 +0000500 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000501 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000502 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100503 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000504 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000505 DataType::QSymmS8,
506 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000507 };
Derek Lambertid466a542020-01-22 15:37:29 +0000508 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000509
510 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000511 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000512 }
513 else
514 {
515 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000516 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000517
518 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000519 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000520 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100521
522 if (biases.has_value())
523 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000524 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000525 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000526 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000527 DataType::Float32,
528 DataType::Float16,
529 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100530 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000531
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100532 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000533 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100534 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000535 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100536
537 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100538}
539
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000540bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
541 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000542 Optional<std::string&> reasonIfUnsupported) const
543{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100544 bool supported = true;
545
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000546 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100547 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000548 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000549 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100550 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000551 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100552 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000553 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000554 DataType::QSymmS16,
555 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100556 };
557
558 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000559 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100560
561 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000562 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100563
564 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000565 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100566
567 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000568}
569
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100570bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
571 const TensorInfo& output,
572 const DepthToSpaceDescriptor& descriptor,
573 Optional<std::string&> reasonIfUnsupported) const
574{
Jan Eilers8eb25602020-03-09 12:13:48 +0000575 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100576 bool supported = true;
577
Sadik Armagan303980c2020-04-17 12:45:14 +0100578 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100579 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000580 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100581 DataType::Float32,
582 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100583 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000584 DataType::QAsymmU8,
585 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100586 };
587
588 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
589 "Reference DepthToSpace: input type not supported");
590
591 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
592 "Reference DepthToSpace: output type not supported");
593
594 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
595 "Reference DepthToSpace: input and output types are mismatched");
596
597 return supported;
598}
599
arovir011c7c81b2018-10-08 11:34:28 +0100600bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
601 const TensorInfo& output,
602 const DepthwiseConvolution2dDescriptor& descriptor,
603 const TensorInfo& weights,
604 const Optional<TensorInfo>& biases,
605 Optional<std::string&> reasonIfUnsupported) const
606{
Sadik Armagan303980c2020-04-17 12:45:14 +0100607 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100608 bool supported = true;
609
610 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000611 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100612 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000613 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100614 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100615 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000616 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000617 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100618 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000619 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100620 };
621
622 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
623 "Reference DepthwiseConvolution2d: input is not a supported type.");
624
625 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
626 "Reference DepthwiseConvolution2d: output is not a supported type.");
627
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100628 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
629 "Reference DepthwiseConvolution2d: input and output types mismatched.");
630
Teresa Charlind8df0262019-11-11 12:28:15 +0000631 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000632 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000633 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100634 ARMNN_NO_DEPRECATE_WARN_BEGIN
635 std::array<DataType, 4> supportedWeightTypes =
636 {
637 DataType::QAsymmS8,
638 DataType::QAsymmU8,
639 DataType::QSymmS8,
640 DataType::QuantizedSymm8PerAxis // deprecated
641 };
642 ARMNN_NO_DEPRECATE_WARN_END
Teresa Charlind8df0262019-11-11 12:28:15 +0000643
644 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +0100645 "Reference DepthwiseConvolution2d: weights type not supported for "
646 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +0000647 }
648 else
649 {
650 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
651 "Reference DepthwiseConvolution2d: weights is not a supported type.");
652
653 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
654 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
655 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100656
657 if (biases.has_value())
658 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000659 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100660 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000661 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100662 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100663 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100664 DataType::Signed32
665 };
666 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
667 "Reference DepthwiseConvolution2d: biases is not a supported type.");
668 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100669
670 return supported;
671
arovir011c7c81b2018-10-08 11:34:28 +0100672}
673
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000674bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
675 const TensorInfo& output,
676 Optional<std::string&> reasonIfUnsupported) const
677{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100678 bool supported = true;
679
Ryan OShea9add1202020-02-07 10:06:33 +0000680 std::array<DataType,4> supportedInputTypes = {
681 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000682 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000683 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000684 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100685 };
686
687 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000688 "Reference for Dequantize layer: input type not supported.");
689
690 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
691 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100692
Derek Lambertid466a542020-01-22 15:37:29 +0000693 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
694 "Reference dequantize: per-axis quantized input not support .");
695
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000696 std::array<DataType,3> supportedOutputTypes = {
697 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +0000698 DataType::Float32,
699 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100700 };
701
702 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000703 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100704
705 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000706 "Reference for Dequantize layer: input/output shapes have different num total "
707 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100708
709 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000710}
711
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000712bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
713 const TensorInfo& scores,
714 const TensorInfo& anchors,
715 const TensorInfo& detectionBoxes,
716 const TensorInfo& detectionClasses,
717 const TensorInfo& detectionScores,
718 const TensorInfo& numDetections,
719 const DetectionPostProcessDescriptor& descriptor,
720 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000721{
Jan Eilers8eb25602020-03-09 12:13:48 +0000722 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +0000723
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100724 bool supported = true;
725
Sadik Armagan303980c2020-04-17 12:45:14 +0100726 std::array<DataType,5> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100727 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000728 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100729 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100730 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000731 DataType::QAsymmU8,
732 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100733 };
734
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000735 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100736 "Reference DetectionPostProcess: input 0 is not a supported type.");
737
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000738 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100739 "Reference DetectionPostProcess: input 1 is not a supported type.");
740
741 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000742}
743
Pablo Tellof0bd6832019-04-26 17:58:13 +0100744bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
745 const TensorInfo& output,
746 const DepthwiseConvolution2dDescriptor& descriptor,
747 const TensorInfo& weights,
748 const Optional<TensorInfo>& biases,
749 Optional<std::string&> reasonIfUnsupported) const
750{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100751 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100752}
753
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100754bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100755 const TensorInfo& input1,
756 const TensorInfo& output,
757 Optional<std::string&> reasonIfUnsupported) const
758{
Sadik Armagan2999a022019-04-09 14:20:12 +0100759 bool supported = true;
760
Sadik Armagan303980c2020-04-17 12:45:14 +0100761 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000762 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100763 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100764 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100765 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000766 DataType::QAsymmU8,
767 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100768 };
769
770 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
771 "Reference division: input 0 is not a supported type.");
772
773 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
774 "Reference division: input 1 is not a supported type.");
775
776 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
777 "Reference division: output is not a supported type.");
778
779 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
780 "Reference division: input 0 and Input 1 types are mismatched");
781
782 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
783 "Reference division: input and output types are mismatched");
784
785 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
786 "Reference division: shapes are not suitable for implicit broadcast.");
787
788 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100789}
790
josh minor4a3c6102020-01-06 16:40:46 -0600791bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
792 const TensorInfo& output,
793 const ElementwiseUnaryDescriptor& descriptor,
794 Optional<std::string&> reasonIfUnsupported) const
795{
Jan Eilers8eb25602020-03-09 12:13:48 +0000796 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -0600797
Sadik Armagan303980c2020-04-17 12:45:14 +0100798 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -0600799 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000800 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -0600801 DataType::Float32,
802 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100803 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -0600804 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +0000805 DataType::QSymmS16,
806 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -0600807 };
808
809 bool supported = true;
810
811 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
812 "Reference elementwise unary: input type not supported");
813
814 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
815 "Reference elementwise unary: output type not supported");
816
817 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
818 "Reference elementwise unary: input and output types not matching");
819
820 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
821 "Reference elementwise unary: input and output shapes"
822 "have different number of total elements");
823
824 return supported;
825}
826
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000827bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
828 const TensorInfo& input1,
829 const TensorInfo& output,
830 Optional<std::string&> reasonIfUnsupported) const
831{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100832 return IsComparisonSupported(input0,
833 input1,
834 output,
835 ComparisonDescriptor(ComparisonOperation::Equal),
836 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000837}
838
arovir011c7c81b2018-10-08 11:34:28 +0100839bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
840 const FakeQuantizationDescriptor& descriptor,
841 Optional<std::string&> reasonIfUnsupported) const
842{
Jan Eilers8eb25602020-03-09 12:13:48 +0000843 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100844 bool supported = true;
845
846 std::array<DataType,1> supportedTypes =
847 {
848 DataType::Float32
849 };
850
851 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
852 "Reference fake quantization: input type not supported.");
853
854 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100855}
856
857bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
858 const TensorInfo& output,
859 Optional<std::string&> reasonIfUnsupported) const
860{
Jan Eilers8eb25602020-03-09 12:13:48 +0000861 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +0100862 bool supported = true;
863
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000864 std::array<DataType,4> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100865 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000866 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +0100867 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100868 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000869 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100870 };
871
872 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
873 "Reference Floor: input type not supported.");
874
875 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
876 "Reference Floor: output type not supported.");
877
878 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100879}
880
881bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
882 const TensorInfo& output,
883 const TensorInfo& weights,
884 const TensorInfo& biases,
885 const FullyConnectedDescriptor& descriptor,
886 Optional<std::string&> reasonIfUnsupported) const
887{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100888 bool supported = true;
889
890 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000891 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100892 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000893 DataType::BFloat16,
894 DataType::Float32,
895 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000896 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100897 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000898 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100899 };
900
901 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
902 "Reference Fully Connected: input type not supported.");
903
904 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
905 "Reference Fully Connected: output type not supported.");
906
Francis Murtagh46c09d02019-05-28 08:15:28 +0100907 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
908 "Reference Fully Connected: weights type not supported.");
909
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000910 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
911 if (input.GetDataType() == DataType::BFloat16)
912 {
913 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
914 {
915 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
916 supported = false;
917 }
918 }
919 else
920 {
921 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
922 "Reference Fully Connected: input and output types mismatched.");
923 }
924
Francis Murtaghddb1d062020-03-10 13:51:45 +0000925 ARMNN_NO_DEPRECATE_WARN_BEGIN
Sadik Armagan303980c2020-04-17 12:45:14 +0100926 std::array<DataType, 4> supportedWeightTypes =
Francis Murtaghddb1d062020-03-10 13:51:45 +0000927 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100928 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +0000929 DataType::QAsymmU8,
930 DataType::QSymmS8,
931 DataType::QuantizedSymm8PerAxis // deprecated
932 };
933 ARMNN_NO_DEPRECATE_WARN_END
934
935 if (IsQuantized8BitType(input.GetDataType()))
936 {
937
938 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
939 "Reference Fully Connected: weights type not supported for quantized input.");
940 }
941 else
942 {
943 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
944 "Reference Fully Connected: weights is not a supported type.");
945
946 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
947 "Reference Fully Connected: input and weights types mismatched.");
948 }
Francis Murtagh46c09d02019-05-28 08:15:28 +0100949
950 if (descriptor.m_BiasEnabled)
951 {
952 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +0100953 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100954 supportedBiasTypes =
955 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000956 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100957 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100958 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +0100959 DataType::Signed32,
960 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +0100961 };
962
963 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
964 "Reference Fully Connected: bias type not supported.");
965
966 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
967 "Reference Fully Connected: bias and weight types mismatch.");
968
969 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
970 "Reference Fully Connected: bias type inferred from weights is incompatible.");
971
Narumol Prangnawarat366d7232020-04-29 12:58:17 +0100972 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
973 "Reference Fully Connected: bias must have 1 dimension.");
974
Francis Murtagh46c09d02019-05-28 08:15:28 +0100975 }
976
977 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100978}
979
narpra014951d842019-01-18 16:53:53 +0000980bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
981 const armnn::TensorInfo& input1,
982 const armnn::TensorInfo& output,
983 armnn::Optional<std::string&> reasonIfUnsupported) const
984{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100985 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +0100986 std::array<DataType,6> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100987 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000988 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100989 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100990 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100991 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000992 DataType::QAsymmU8,
993 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100994 };
995
996 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
997 "Reference Gather: input type not supported");
998
999 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1000 "Reference Gather: output type not supported");
1001
1002 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1003 "Reference Gather: indices (input1) type not supported");
1004
1005 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1006 "Reference Gather: input and output types not matching");
1007
1008 return supported;
narpra014951d842019-01-18 16:53:53 +00001009}
1010
FrancisMurtagh878f0232018-12-19 10:56:15 +00001011bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
1012 const TensorInfo& input1,
1013 const TensorInfo& output,
1014 Optional<std::string&> reasonIfUnsupported) const
1015{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001016 return IsComparisonSupported(input0,
1017 input1,
1018 output,
1019 ComparisonDescriptor(ComparisonOperation::Greater),
1020 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +00001021}
1022
Derek Lamberti901ea112019-12-10 22:07:09 +00001023bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1024 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001025{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001026 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001027}
1028
Kevin May09ca49c2019-10-09 12:37:34 +01001029bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1030 const TensorInfo& output,
1031 const InstanceNormalizationDescriptor& descriptor,
1032 Optional<std::string&> reasonIfUnsupported) const
1033{
Jan Eilers8eb25602020-03-09 12:13:48 +00001034 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001035 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001036 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001037 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001038 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001039 DataType::Float32,
1040 DataType::Float16
1041 };
1042
1043 bool supported = true;
1044
1045 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1046 "Reference Instance Normalization: input type not supported.");
1047
1048 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1049 "Reference Instance Normalization: output type not supported.");
1050
1051 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1052 "Reference Instance Normalization: input and output types mismatched.");
1053
1054 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1055 "Reference Instance Normalization: input and output shapes have different "
1056 "num total elements.");
1057
1058 return supported;
1059}
1060
arovir011c7c81b2018-10-08 11:34:28 +01001061bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1062 const TensorInfo& output,
1063 const L2NormalizationDescriptor& descriptor,
1064 Optional<std::string&> reasonIfUnsupported) const
1065{
Jan Eilers8eb25602020-03-09 12:13:48 +00001066 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001067 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001068 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001069 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001070 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001071 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001072 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001073 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001074 DataType::QAsymmU8,
1075 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001076 };
1077
1078 bool supported = true;
1079
1080 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1081 "Reference L2normalization: input type not supported.");
1082
1083 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1084 "Reference L2normalization: output type not supported.");
1085
1086 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1087 "Reference L2normalization: input and output types mismatched.");
1088
1089 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1090 "Reference L2normalization: input and output shapes have different "
1091 "num total elements.");
1092
1093 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001094}
1095
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001096bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1097 const TensorInfo& output,
1098 const LogSoftmaxDescriptor& descriptor,
1099 Optional<std::string&> reasonIfUnsupported) const
1100{
Jan Eilers8eb25602020-03-09 12:13:48 +00001101 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001102
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001103 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001104 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001105 DataType::BFloat16,
1106 DataType::Float32,
1107 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001108 };
1109
1110 bool supported = true;
1111 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1112 "Reference LogSoftmax: input type not supported");
1113
1114 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1115 "Reference LogSoftmax: output type not supported");
1116
1117 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1118 "Reference LogSoftmax: input and output types do not match");
1119
1120 return supported;
1121}
1122
arovir011c7c81b2018-10-08 11:34:28 +01001123bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1124 const TensorInfo& outputStateIn,
1125 const TensorInfo& cellStateIn,
1126 const TensorInfo& scratchBuffer,
1127 const TensorInfo& outputStateOut,
1128 const TensorInfo& cellStateOut,
1129 const TensorInfo& output,
1130 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001131 const LstmInputParamsInfo& paramsInfo,
1132 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001133{
Jan Eilers8eb25602020-03-09 12:13:48 +00001134 IgnoreUnused(descriptor);
1135 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001136
1137 bool supported = true;
1138
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001139 std::array<DataType,3> supportedTypes = {
1140 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001141 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001142 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001143 };
1144
Jan Eilersd01a83c2019-07-03 18:20:40 +01001145 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001146 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1147 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001148 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1149 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001150 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1151 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001152 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1153 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001154 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1155 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001156 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1157 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001158 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1159 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001160 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001161 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001162 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001163 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001164 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001165 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001166 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001167 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001168 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001169 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001170 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001171 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001172 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001173 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001174 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001175 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001176 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001177 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001178 "Reference Lstm: input and OutputGateBias types are mismatched");
1179 if (!descriptor.m_CifgEnabled)
1180 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001181 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001182 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001183 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001184 reasonIfUnsupported,
1185 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001186 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001187 "Reference Lstm: input and InputGateBias types are mismatched");
1188 if (descriptor.m_PeepholeEnabled)
1189 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001190 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001191 reasonIfUnsupported,
1192 "Reference Lstm: input and CellToInputWeights types are mismatched");
1193 }
1194 }
1195 if (descriptor.m_PeepholeEnabled)
1196 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001197 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001198 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001199 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001200 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1201 }
1202 if (descriptor.m_ProjectionEnabled)
1203 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001204 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001205 "Reference Lstm: input and mProjectionWeights types are mismatched");
1206 if (paramsInfo.m_ProjectionBias != nullptr)
1207 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001208 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001209 "Reference Lstm: input and ProjectionBias types are mismatched");
1210 }
1211 }
1212 if (descriptor.m_LayerNormEnabled)
1213 {
1214 if (!descriptor.m_CifgEnabled)
1215 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001216 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001217 reasonIfUnsupported,
1218 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1219 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001220 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001221 reasonIfUnsupported,
1222 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001223 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001224 reasonIfUnsupported,
1225 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001226 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001227 reasonIfUnsupported,
1228 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1229 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001230
1231 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001232}
1233
saoste012df12b32018-11-28 16:57:20 +00001234bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1235 const TensorInfo& input1,
1236 const TensorInfo& output,
1237 Optional<std::string&> reasonIfUnsupported) const
1238{
Sadik Armagan2999a022019-04-09 14:20:12 +01001239 bool supported = true;
1240
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001241 std::array<DataType,6> supportedTypes = {
1242 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001243 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001244 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001245 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001246 DataType::QAsymmU8,
1247 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001248 };
1249
1250 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1251 "Reference maximum: input 0 is not a supported type.");
1252
1253 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1254 "Reference maximum: input 1 is not a supported type.");
1255
1256 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1257 "Reference maximum: output is not a supported type.");
1258
1259 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1260 "Reference maximum: input 0 and Input 1 types are mismatched");
1261
1262 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1263 "Reference maximum: input and output types are mismatched");
1264
1265 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1266 "Reference maximum: shapes are not suitable for implicit broadcast.");
1267
1268 return supported;
saoste012df12b32018-11-28 16:57:20 +00001269}
1270
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001271bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1272 const TensorInfo& output,
1273 const MeanDescriptor& descriptor,
1274 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001275{
James Conroy4d1ff582019-06-10 17:06:39 +01001276 bool supported = true;
1277 std::string meanLayerStr = "Mean";
1278 std::string outputTensorStr = "output";
1279
Sadik Armagan303980c2020-04-17 12:45:14 +01001280 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001281 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001282 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001283 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001284 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001285 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001286 DataType::QAsymmU8,
1287 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001288 };
1289
1290 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1291 "Reference Mean: input type not supported.");
1292
1293 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1294 "Reference Mean: input and output types are mismatched");
1295
1296 if (descriptor.m_KeepDims)
1297 {
1298 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1299 reasonIfUnsupported,
1300 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1301 output.GetNumDimensions(),
1302 meanLayerStr, outputTensorStr).data());
1303 }
1304 else if (descriptor.m_Axis.empty())
1305 {
1306 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1307 reasonIfUnsupported,
1308 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1309 meanLayerStr, outputTensorStr).data());
1310 }
1311 else
1312 {
1313 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1314
1315 if (outputDim > 0)
1316 {
1317 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1318 reasonIfUnsupported,
1319 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1320 meanLayerStr, outputTensorStr).data());
1321 }
1322 else
1323 {
1324 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1325 reasonIfUnsupported,
1326 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1327 meanLayerStr, outputTensorStr).data());
1328 }
1329 }
1330
1331 return supported;
narpra0132b90462018-09-13 11:07:48 +01001332}
1333
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001334bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001335 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001336 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001337 Optional<std::string&> reasonIfUnsupported) const
1338{
Jim Flynne242f2d2019-05-22 14:24:13 +01001339 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001340}
1341
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001342bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1343 const TensorInfo &output,
1344 Optional<std::string &> reasonIfUnsupported) const
1345{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001346 bool supported = true;
1347
Sadik Armagan303980c2020-04-17 12:45:14 +01001348 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001349 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001350 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001351 DataType::Float32,
1352 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001353 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001354 DataType::QAsymmU8,
1355 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001356 DataType::Boolean
1357 };
1358
1359 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1360 "Reference MemCopy: input type not supported");
1361
1362 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1363 "Reference MemCopy: output type not supported");
1364
1365 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1366 "Reference MemCopy: input and output types are mismatched");
1367
1368 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001369}
1370
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001371bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1372 const TensorInfo& input1,
1373 const TensorInfo& output,
1374 Optional<std::string&> reasonIfUnsupported) const
1375{
Sadik Armagan2999a022019-04-09 14:20:12 +01001376 bool supported = true;
1377
Sadik Armagan303980c2020-04-17 12:45:14 +01001378 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001379 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001380 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001381 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001382 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001383 DataType::QAsymmU8,
1384 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001385 };
1386
1387 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1388 "Reference minimum: input 0 is not a supported type.");
1389
1390 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1391 "Reference minimum: input 1 is not a supported type.");
1392
1393 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1394 "Reference minimum: output is not a supported type.");
1395
1396 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1397 "Reference minimum: input 0 and Input 1 types are mismatched");
1398
1399 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1400 "Reference minimum: input and output types are mismatched");
1401
1402 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1403 "Reference minimum: shapes are not suitable for implicit broadcast.");
1404
1405 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001406}
1407
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001408bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1409 const TensorInfo& input1,
1410 const TensorInfo& output,
1411 Optional<std::string&> reasonIfUnsupported) const
1412{
Sadik Armagan2999a022019-04-09 14:20:12 +01001413 bool supported = true;
1414
Keith Davis67e6c542020-02-19 10:08:33 +00001415 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001416 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001417 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001418 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001419 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001420 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001421 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001422 };
1423
1424 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1425 "Reference multiplication: input 0 is not a supported type.");
1426
1427 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1428 "Reference multiplication: input 1 is not a supported type.");
1429
1430 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1431 "Reference multiplication: output is not a supported type.");
1432
1433 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1434 "Reference multiplication: input 0 and Input 1 types are mismatched");
1435
1436 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1437 "Reference multiplication: input and output types are mismatched");
1438
1439 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1440 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1441
1442 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001443}
1444
1445bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1446 const TensorInfo& output,
1447 const NormalizationDescriptor& descriptor,
1448 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001449{
Jan Eilers8eb25602020-03-09 12:13:48 +00001450 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001451
1452 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001453 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001454 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001455 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001456 DataType::Float16,
1457 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001458 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001459 DataType::QAsymmU8,
1460 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001461 };
1462
1463 bool supported = true;
1464
1465 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1466 "Reference normalization: input type not supported.");
1467
1468 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1469 "Reference normalization: output type not supported.");
1470
1471 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1472 "Reference normalization: input and output shapes have different "
1473 "num total elements.");
1474
1475 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001476}
1477
Derek Lamberti901ea112019-12-10 22:07:09 +00001478bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1479 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001480{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001481 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001482}
1483
1484bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1485 const TensorInfo& output,
1486 const PadDescriptor& descriptor,
1487 Optional<std::string&> reasonIfUnsupported) const
1488{
Jan Eilers8eb25602020-03-09 12:13:48 +00001489 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001490 bool supported = true;
1491
1492 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001493 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001494 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001495 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001496 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001497 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001498 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001499 DataType::QAsymmU8,
1500 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001501 };
1502
1503 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1504 "Reference pad: input is not a supported type.");
1505
1506 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1507 "Reference pad: output is not a supported type.");
1508
1509 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1510 "Reference pad: input and output types are mismatched.");
1511
1512 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001513}
1514
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001515bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1516 const TensorInfo& output,
1517 const PermuteDescriptor& descriptor,
1518 Optional<std::string&> reasonIfUnsupported) const
1519{
Jan Eilers8eb25602020-03-09 12:13:48 +00001520 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001521 bool supported = true;
1522
1523 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001524 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001525 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001526 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001527 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001528 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001529 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001530 DataType::QAsymmU8,
1531 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001532 };
1533
1534 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1535 "Reference permute: input is not a supported type.");
1536
1537 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1538 "Reference permute: output is not a supported type.");
1539
1540 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1541 "Reference permute: input and output types are mismatched.");
1542
1543 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001544}
1545
1546bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1547 const TensorInfo& output,
1548 const Pooling2dDescriptor& descriptor,
1549 Optional<std::string&> reasonIfUnsupported) const
1550{
Jan Eilers8eb25602020-03-09 12:13:48 +00001551 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001552 bool supported = true;
1553
1554 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001555 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001556 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001557 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001558 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001559 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001560 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001561 DataType::QAsymmU8,
1562 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001563 };
1564
1565 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1566 "Reference poolind2d: input is not a supported type.");
1567
1568 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1569 "Reference poolind2d: output is not a supported type.");
1570
1571 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1572 "Reference poolind2d: input and output types are mismatched.");
1573
1574 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001575}
1576
James Conroy4f1f8992020-04-29 20:01:10 +01001577bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1578 const TensorInfo& previousOutputIn,
1579 const TensorInfo& previousCellStateIn,
1580 const TensorInfo& outputStateOut,
1581 const TensorInfo& cellStateOut,
1582 const TensorInfo& output,
1583 const QLstmDescriptor& descriptor,
1584 const LstmInputParamsInfo& paramsInfo,
1585 Optional<std::string&> reasonIfUnsupported) const
1586{
1587 IgnoreUnused(input);
1588 IgnoreUnused(previousOutputIn);
1589 IgnoreUnused(previousCellStateIn);
1590 IgnoreUnused(outputStateOut);
1591 IgnoreUnused(cellStateOut);
1592 IgnoreUnused(output);
1593 IgnoreUnused(descriptor);
1594 IgnoreUnused(paramsInfo);
1595
1596 IgnoreUnused(reasonIfUnsupported);
1597
1598 return true;
1599}
1600
Derek Lamberti5f400d62019-03-25 15:41:58 +00001601bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1602 const TensorInfo& output,
1603 Optional<std::string&> reasonIfUnsupported) const
1604{
1605 bool supported = true;
1606
Finn Williamsfd271062019-12-04 14:27:27 +00001607 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001608 std::array<DataType,7> supportedInputTypes = {
1609 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001610 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001611 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001612 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001613 DataType::QAsymmU8,
1614 DataType::QSymmS8,
1615 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001616 };
1617
1618 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1619 "Reference quantize: input type not supported.");
1620
1621 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001622 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001623 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001624 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001625 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001626 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001627 };
1628 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1629 "Reference quantize: output type not supported.");
1630
1631 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1632 "Reference quantize: input and output shapes have different num total elements.");
1633
1634 return supported;
1635}
1636
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001637bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001638 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001639 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001640 Optional<std::string&> reasonIfUnsupported) const
1641{
Jan Eilers8eb25602020-03-09 12:13:48 +00001642 IgnoreUnused(output);
1643 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001644 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001645 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001647 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001648 DataType::Float32,
1649 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001650 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001652 DataType::QAsymmU8,
1653 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001654 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001655
Nina Drozd2f2778f2019-05-27 10:37:05 +01001656 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1657 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001658}
1659
1660bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001661 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001662 Optional<std::string&> reasonIfUnsupported) const
1663{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001664 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001665 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001666 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001667 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001668 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001669 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001670 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001671 DataType::QAsymmU8,
1672 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001673 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001674
1675 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1676 "Reference ResizeBilinear: input type not supported");
1677
1678 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1679 "Reference ResizeBilinear: output type not supported");
1680
1681 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1682 "Reference ResizeBilinear: input and output types not matching");
1683
1684 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001685}
1686
Teresa Charlin970f43b2019-07-01 13:51:07 +01001687bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1688 const TensorInfo& output,
1689 const ResizeDescriptor& descriptor,
1690 Optional<std::string&> reasonIfUnsupported) const
1691{
Jan Eilers8eb25602020-03-09 12:13:48 +00001692 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001693 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001694 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001695 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001696 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001697 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001698 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001699 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001700 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001701 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001702 };
1703
1704 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1705 "Reference Resize: input type not supported");
1706
1707 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1708 "Reference Resize: output type not supported");
1709
1710 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1711 "Reference Resize: input and output types not matching");
1712
1713 return supported;
1714}
1715
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001716bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1717 const TensorInfo& output,
1718 Optional<std::string&> reasonIfUnsupported) const
1719{
josh minor4a3c6102020-01-06 16:40:46 -06001720 return IsElementwiseUnarySupported(input,
1721 output,
1722 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1723 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001724}
1725
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001726bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1727 const TensorInfo& output,
1728 const SliceDescriptor& descriptor,
1729 Optional<std::string&> reasonIfUnsupported) const
1730{
Jan Eilers8eb25602020-03-09 12:13:48 +00001731 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001732 bool supported = true;
1733
Sadik Armagan303980c2020-04-17 12:45:14 +01001734 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001735 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001736 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001737 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001738 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001739 DataType::QAsymmU8,
1740 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001741 };
1742
1743 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1744 "Reference Slice: input type not supported");
1745
1746 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1747 "Reference Slice: output type not supported");
1748
1749 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1750 "Reference Slice: input and output types are mismatched");
1751
1752 return supported;
1753}
1754
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001755bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1756 const TensorInfo& output,
1757 const SoftmaxDescriptor& descriptor,
1758 Optional<std::string&> reasonIfUnsupported) const
1759{
Jan Eilers8eb25602020-03-09 12:13:48 +00001760 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001761 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001762 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001763 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001764 DataType::BFloat16,
1765 DataType::Float32,
1766 DataType::Float16,
1767 DataType::QSymmS8,
1768 DataType::QAsymmS8,
1769 DataType::QAsymmU8,
1770 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001771 };
1772
1773 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001774 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001775
1776 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001777 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001778
1779 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001780 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001781
1782 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001783}
1784
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001785bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1786 const TensorInfo& output,
1787 const SpaceToBatchNdDescriptor& descriptor,
1788 Optional<std::string&> reasonIfUnsupported) const
1789{
Jan Eilers8eb25602020-03-09 12:13:48 +00001790 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001791 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001792 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001793 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001794 DataType::BFloat16,
1795 DataType::Float32,
1796 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001797 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001798 DataType::QAsymmU8,
1799 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001800 };
1801
1802 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1803 "Reference SpaceToBatchNd: input type not supported");
1804
1805 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1806 "Reference SpaceToBatchNd: output type not supported");
1807
1808 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1809 "Reference SpaceToBatchNd: input and output types are mismatched");
1810
1811 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001812}
1813
Keith Davisa57eccb2019-06-14 17:33:22 +01001814bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001815 const TensorInfo& output,
1816 const SpaceToDepthDescriptor& descriptor,
1817 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001818{
1819
Jan Eilers8eb25602020-03-09 12:13:48 +00001820 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001821 bool supported = true;
1822
Sadik Armagan303980c2020-04-17 12:45:14 +01001823 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001824 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001825 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001826 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001827 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001828 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001829 DataType::QAsymmU8,
1830 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001831 };
1832
1833 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1834 "Reference SpaceToDepth: input type not supported");
1835
1836 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1837 "Reference SpaceToDepth: output type not supported");
1838
1839 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1840 "Reference SpaceToDepth: input and output types are mismatched");
1841
1842 return supported;
1843}
1844
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001845bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1846 const ViewsDescriptor& descriptor,
1847 Optional<std::string&> reasonIfUnsupported) const
1848{
Jan Eilers8eb25602020-03-09 12:13:48 +00001849 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001850 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001851 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001852 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001853 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001854 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001855 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001856 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001857 DataType::QAsymmU8,
1858 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001859 };
1860
1861 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1862 "Reference splitter: input type not supported");
1863
1864 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001865}
1866
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001867bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1868 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1869 const ViewsDescriptor& descriptor,
1870 Optional<std::string&> reasonIfUnsupported) const
1871{
Jan Eilers8eb25602020-03-09 12:13:48 +00001872 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001873 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001874 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001875 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001876 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001877 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001878 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001879 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001880 DataType::QAsymmU8,
1881 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001882 };
1883
1884 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1885 "Reference splitter: output type not supported");
1886 for (const TensorInfo output : outputs)
1887 {
1888 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1889 "Reference splitter: input type not supported");
1890
1891 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1892 "Reference splitter: input and output types mismatched.");
1893 }
1894
1895 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001896}
1897
Matthew Jackson81e601c2019-07-11 12:07:09 +01001898bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1899 const TensorInfo& output,
1900 const StackDescriptor& descriptor,
1901 Optional<std::string&> reasonIfUnsupported) const
1902{
Jan Eilers8eb25602020-03-09 12:13:48 +00001903 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001904
1905 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001906 std::array<DataType,6> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001907 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001908 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001909 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001910 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001911 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001912 DataType::QAsymmU8,
1913 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001914 };
1915
1916 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1917 "Reference stack: output type not supported");
1918 for (const TensorInfo* input : inputs)
1919 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001920 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001921 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1922 "Reference stack: input type not supported");
1923
1924 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1925 "Reference stack: input and output types mismatched.");
1926 }
1927
1928 return supported;
1929}
1930
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001931bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1932 const TensorInfo& output,
1933 const StridedSliceDescriptor& descriptor,
1934 Optional<std::string&> reasonIfUnsupported) const
1935{
Jan Eilers8eb25602020-03-09 12:13:48 +00001936 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001937 bool supported = true;
1938
Sadik Armagan303980c2020-04-17 12:45:14 +01001939 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001940 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001941 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001942 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001943 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001944 DataType::QAsymmU8,
1945 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001946 };
1947
1948 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1949 "Reference StridedSlice: input type not supported");
1950
1951 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1952 "Reference StridedSlice: output type not supported");
1953
1954 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1955 "Reference StridedSlice: input and output types are mismatched");
1956
1957 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001958}
1959
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001960bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1961 const TensorInfo& input1,
1962 const TensorInfo& output,
1963 Optional<std::string&> reasonIfUnsupported) const
1964{
Sadik Armagan2999a022019-04-09 14:20:12 +01001965 bool supported = true;
1966
Sadik Armagan303980c2020-04-17 12:45:14 +01001967 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001968 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001969 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001970 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001971 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001972 DataType::QAsymmU8,
1973 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001974 };
1975
1976 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1977 "Reference subtraction: input 0 is not a supported type.");
1978
1979 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1980 "Reference subtraction: input 1 is not a supported type.");
1981
1982 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1983 "Reference subtraction: output is not a supported type.");
1984
1985 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1986 "Reference subtraction: input 0 and Input 1 types are mismatched");
1987
1988 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1989 "Reference subtraction: input and output types are mismatched");
1990
1991 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1992 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1993
1994 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001995}
1996
Matteo Martincighab9e5252019-06-13 17:27:46 +01001997bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1998 const TensorInfo& alpha,
1999 const TensorInfo& output,
2000 Optional<std::string&> reasonIfUnsupported) const
2001{
2002 bool supported = true;
2003
Sadik Armagan303980c2020-04-17 12:45:14 +01002004 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002005 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002006 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002007 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002008 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002009 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002010 DataType::QAsymmU8,
2011 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002012 };
2013
2014 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2015 "PReLU: input is not a supported type.");
2016
2017 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2018 "PReLU: alpha is not a supported type.");
2019
2020 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2021 "PReLU: output is not a supported type.");
2022
2023 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2024 "PReLU: input, alpha and output types are mismatched");
2025
2026 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2027 "PReLU: shapes are not suitable for implicit broadcast");
2028
2029 return supported;
2030}
2031
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002032bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2033 const TensorInfo& output,
2034 const TransposeConvolution2dDescriptor& descriptor,
2035 const TensorInfo& weights,
2036 const Optional<TensorInfo>& biases,
2037 Optional<std::string&> reasonIfUnsupported) const
2038{
Jan Eilers8eb25602020-03-09 12:13:48 +00002039 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002040 bool supported = true;
2041
Sadik Armagan303980c2020-04-17 12:45:14 +01002042 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002043 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002044 DataType::BFloat16,
2045 DataType::Float32,
2046 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002047 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002048 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002049 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002050 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002051 };
2052
2053 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2054 "Reference TransposeConvolution2d: input is not a supported type.");
2055
2056 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2057 "Reference TransposeConvolution2d: output is not a supported type.");
2058
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002059 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2060 "Reference TransposeConvolution2d: input and output types mismatched.");
2061
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002062
2063 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002064 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002065 {
Derek Lambertid466a542020-01-22 15:37:29 +00002066 ARMNN_NO_DEPRECATE_WARN_BEGIN
Sadik Armagan303980c2020-04-17 12:45:14 +01002067 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002068 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002069 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002070 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00002071 DataType::QSymmS8,
2072 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002073 };
Derek Lambertid466a542020-01-22 15:37:29 +00002074 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002075
2076 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2077 "Reference TransposeConvolution2d: weights type not supported for "
2078 "quantized input.");
2079 }
2080 else
2081 {
2082 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2083 "Reference TransposeConvolution2d: weights is not a supported type.");
2084
2085 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2086 "Reference TransposeConvolution2d: input and weights types mismatched.");
2087 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002088
2089 if (biases.has_value())
2090 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002091 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002092 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002093 DataType::BFloat16,
2094 DataType::Float32,
2095 DataType::Float16,
2096 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002097 };
2098 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2099 "Reference TransposeConvolution2d: biases is not a supported type.");
2100 }
2101
2102 return supported;
2103}
2104
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002105bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2106 const TensorInfo& output,
2107 const TransposeDescriptor& descriptor,
2108 Optional<std::string&> reasonIfUnsupported) const
2109{
Jan Eilers8eb25602020-03-09 12:13:48 +00002110 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002111 bool supported = true;
2112
2113 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002114 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002115 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002116 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002117 DataType::Float32,
2118 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002119 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002120 DataType::QAsymmU8,
2121 DataType::QSymmS16
2122 };
2123
2124 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2125 "Reference transpose: input is not a supported type.");
2126
2127 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2128 "Reference transpose: output is not a supported type.");
2129
2130 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2131 "Reference transpose: input and output types are mismatched.");
2132
2133 return supported;
2134}
2135
arovir011c7c81b2018-10-08 11:34:28 +01002136} // namespace armnn