//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>

#include <LayerSupportCommon.hpp>

#include <backendsCommon/LayerSupportRules.hpp>

#include <boost/cast.hpp>
#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

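// Illustrative only (not part of the original source): the helper above produces messages of the form
//   "Reference Mean: Expected 4 dimensions but got 2 dimensions instead, for the 'output' tensor."
// where "Mean", 4, 2 and "output" are hypothetical values supplied by the caller.
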
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return IsElementwiseUnarySupported(input,
                                       output,
                                       ElementwiseUnaryDescriptor(UnaryOperation::Abs),
                                       reasonIfUnsupported);
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,5> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

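// Illustrative usage sketch (not part of the original source); the TensorInfo shape and descriptor values
// below are hypothetical:
//   TensorInfo info(TensorShape({1, 2, 2, 3}), DataType::Float32);
//   std::string reason;
//   RefLayerSupport support;
//   bool ok = support.IsActivationSupported(info, info, ActivationDescriptor(), Optional<std::string&>(reason));
//   // If any rule in the activation check above fails, 'reason' holds the corresponding message(s).
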
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,5> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

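// Illustrative note (not part of the original source): ShapesAreBroadcastCompatible above accepts, for example,
// hypothetical shapes input0 = [1, 2, 2, 3] and input1 = [1, 1, 1, 3] with output = [1, 2, 2, 3], because every
// input dimension either matches the corresponding output dimension or is 1.
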
bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo& input, const armnn::TensorInfo& output,
                                           const armnn::ArgMinMaxDescriptor& descriptor,
                                           armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

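// Illustrative note (not part of the original source): ArgMinMax returns element indices, so the check above
// requires a Signed32 output regardless of the input type; e.g. a hypothetical QAsymmU8 input of shape [1, 4]
// reduced over axis 1 would still need a Signed32 output of shape [1].
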
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

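// Illustrative note (not part of the original source): the two TensorNumDimensionsAreCorrect checks above mean
// the reference BatchToSpaceNd only accepts 4-D tensors; e.g. a hypothetical input of shape [4, 1, 1, 1] with
// block shape [2, 2] could map to an output of shape [1, 2, 2, 1], while any non-4-D input or output is rejected
// with the message built by CreateIncorrectDimensionsErrorMsg.
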
bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,5> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,6> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QAsymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference Convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

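// Illustrative note (not part of the original source): on the quantized branch above, a hypothetical QAsymmU8
// input may be paired with QSymmS8 (or per-axis QuantizedSymm8PerAxis) weights, so the usual "input and weights
// types must match" rule is only applied on the non-quantized branch.
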
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 7> supportedTypes =
    {
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QAsymmS8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference for Debug layer: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS8,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    ARMNN_NO_DEPRECATE_WARN_BEGIN
    std::array<DataType, 3> supportedWeightTypes =
    {
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QuantizedSymm8PerAxis // deprecated
    };
    ARMNN_NO_DEPRECATE_WARN_END

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedInputTypes = {
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input type not supported.");

    supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
                                  "Reference for Dequantize layer: per-axis quantized input not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input/output shapes have different num total "
                                  "elements.");

    return supported;
}

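// Illustrative note (not part of the original source): ShapesAreSameTotalSize above compares element counts
// rather than shapes, so a hypothetical QAsymmU8 input of shape [2, 3] may dequantize into a Float32 output of
// shape [6] or [3, 2], but not into one of shape [2, 2].
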
bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  const ElementwiseUnaryDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output shapes "
                                  "have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

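// Illustrative note (not part of the original source): when the descriptor enables bias, the rules above tie the
// bias type to the weights type; hypothetically, Float32 weights take a Float32 bias while QAsymmU8 weights take
// a Signed32 bias, which is what BiasAndWeightsTypesCompatible encodes.
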
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and ProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

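// Illustrative note (not part of the original source): the LSTM check above only accepts Float32 or QSymmS16
// data and requires every state, output and parameter tensor to share the input's data type; optional tensors
// (CIFG, peephole, projection and layer-norm weights) are only checked when the corresponding descriptor flag
// enables them.
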
saoste012df12b32018-11-28 16:57:20 +00001114bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1115 const TensorInfo& input1,
1116 const TensorInfo& output,
1117 Optional<std::string&> reasonIfUnsupported) const
1118{
Sadik Armagan2999a022019-04-09 14:20:12 +01001119 bool supported = true;
1120
Keith Davis5204aa82020-01-27 15:24:59 +00001121 std::array<DataType,5> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001122 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001123 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001124 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001125 DataType::QAsymmU8,
1126 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001127 };
1128
1129 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1130 "Reference maximum: input 0 is not a supported type.");
1131
1132 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1133 "Reference maximum: input 1 is not a supported type.");
1134
1135 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1136 "Reference maximum: output is not a supported type.");
1137
1138 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1139                                  "Reference maximum: input 0 and input 1 types are mismatched");
1140
1141 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1142 "Reference maximum: input and output types are mismatched");
1143
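    // Implicit broadcast requires that, for every dimension, each input size either matches the
    // corresponding output size or is 1 (for example [1,2,3,4] and [1,1,1,4] broadcast to [1,2,3,4]).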
1144 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1145 "Reference maximum: shapes are not suitable for implicit broadcast.");
1146
1147 return supported;
saoste012df12b32018-11-28 16:57:20 +00001148}
1149
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001150bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1151 const TensorInfo& output,
1152 const MeanDescriptor& descriptor,
1153 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001154{
James Conroy4d1ff582019-06-10 17:06:39 +01001155 bool supported = true;
1156 std::string meanLayerStr = "Mean";
1157 std::string outputTensorStr = "output";
1158
Matthew Jackson252df3a2019-09-11 09:19:18 +01001159 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001160 {
1161 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001162 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001163 DataType::QAsymmU8,
1164 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001165 };
1166
1167 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1168 "Reference Mean: input type not supported.");
1169
1170 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1171 "Reference Mean: input and output types are mismatched");
1172
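    // Check that the output rank is consistent with the descriptor: KeepDims preserves the input
    // rank, an empty axis list reduces the tensor to a single dimension, and otherwise one
    // dimension is dropped per reduced axis (collapsing to rank 1 when every axis is reduced).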
1173 if (descriptor.m_KeepDims)
1174 {
1175 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1176 reasonIfUnsupported,
1177 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1178 output.GetNumDimensions(),
1179 meanLayerStr, outputTensorStr).data());
1180 }
1181 else if (descriptor.m_Axis.empty())
1182 {
1183 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1184 reasonIfUnsupported,
1185 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1186 meanLayerStr, outputTensorStr).data());
1187 }
1188 else
1189 {
1190 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1191
1192 if (outputDim > 0)
1193 {
1194 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1195 reasonIfUnsupported,
1196 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1197 meanLayerStr, outputTensorStr).data());
1198 }
1199 else
1200 {
1201 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1202 reasonIfUnsupported,
1203 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1204 meanLayerStr, outputTensorStr).data());
1205 }
1206 }
1207
1208 return supported;
narpra0132b90462018-09-13 11:07:48 +01001209}
1210
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001211bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001212 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001213 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001214 Optional<std::string&> reasonIfUnsupported) const
1215{
Jim Flynne242f2d2019-05-22 14:24:13 +01001216 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001217}
1218
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001219bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1220 const TensorInfo &output,
1221 Optional<std::string &> reasonIfUnsupported) const
1222{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001223 bool supported = true;
1224
1225 std::array<DataType,5> supportedTypes =
1226 {
1227 DataType::Float32,
1228 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001229 DataType::QAsymmU8,
1230 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001231 DataType::Boolean
1232 };
1233
1234 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1235 "Reference MemCopy: input type not supported");
1236
1237 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1238 "Reference MemCopy: output type not supported");
1239
1240 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1241 "Reference MemCopy: input and output types are mismatched");
1242
1243 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001244}
1245
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001246bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1247 const TensorInfo& input1,
1248 const TensorInfo& output,
1249 Optional<std::string&> reasonIfUnsupported) const
1250{
Sadik Armagan2999a022019-04-09 14:20:12 +01001251 bool supported = true;
1252
Matthew Jackson9bff1442019-09-12 09:08:23 +01001253 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001254 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001255 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001256 DataType::QAsymmU8,
1257 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001258 };
1259
1260 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1261 "Reference minimum: input 0 is not a supported type.");
1262
1263 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1264 "Reference minimum: input 1 is not a supported type.");
1265
1266 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1267 "Reference minimum: output is not a supported type.");
1268
1269 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1270                                  "Reference minimum: input 0 and input 1 types are mismatched");
1271
1272 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1273 "Reference minimum: input and output types are mismatched");
1274
1275 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1276 "Reference minimum: shapes are not suitable for implicit broadcast.");
1277
1278 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001279}
1280
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001281bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1282 const TensorInfo& input1,
1283 const TensorInfo& output,
1284 Optional<std::string&> reasonIfUnsupported) const
1285{
Sadik Armagan2999a022019-04-09 14:20:12 +01001286 bool supported = true;
1287
Keith Davis67e6c542020-02-19 10:08:33 +00001288 std::array<DataType,6> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001289 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001290 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001291 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001292 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001293 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001294 };
1295
1296 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1297 "Reference multiplication: input 0 is not a supported type.");
1298
1299 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1300 "Reference multiplication: input 1 is not a supported type.");
1301
1302 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1303 "Reference multiplication: output is not a supported type.");
1304
1305 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1306                                  "Reference multiplication: input 0 and input 1 types are mismatched");
1307
1308 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1309 "Reference multiplication: input and output types are mismatched");
1310
1311 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1312 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1313
1314 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001315}
1316
1317bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1318 const TensorInfo& output,
1319 const NormalizationDescriptor& descriptor,
1320 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001321{
Nina Drozd661dfa72018-10-02 11:14:17 +01001322 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001323
1324 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001325 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001326 {
1327 DataType::Float16,
1328 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001329 DataType::QAsymmU8,
1330 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001331 };
1332
1333 bool supported = true;
1334
1335 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1336 "Reference normalization: input type not supported.");
1337
1338 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1339 "Reference normalization: output type not supported.");
1340
1341 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1342 "Reference normalization: input and output shapes have different "
1343 "num total elements.");
1344
1345 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001346}
1347
Derek Lamberti901ea112019-12-10 22:07:09 +00001348bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1349 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001350{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001351 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001352}
1353
1354bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1355 const TensorInfo& output,
1356 const PadDescriptor& descriptor,
1357 Optional<std::string&> reasonIfUnsupported) const
1358{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001359 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001360 bool supported = true;
1361
1362 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001363 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001364 {
1365 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001366 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001367 DataType::QAsymmU8,
1368 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001369 };
1370
1371 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1372 "Reference pad: input is not a supported type.");
1373
1374 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1375 "Reference pad: output is not a supported type.");
1376
1377 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1378 "Reference pad: input and output types are mismatched.");
1379
1380 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001381}
1382
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001383bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1384 const TensorInfo& output,
1385 const PermuteDescriptor& descriptor,
1386 Optional<std::string&> reasonIfUnsupported) const
1387{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001388 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001389 bool supported = true;
1390
1391 // Define supported output and inputs types.
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001392 std::array<DataType, 4> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001393 {
1394 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001395 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001396 DataType::QAsymmU8,
1397 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001398 };
1399
1400 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1401 "Reference permute: input is not a supported type.");
1402
1403 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1404 "Reference permute: output is not a supported type.");
1405
1406 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1407 "Reference permute: input and output types are mismatched.");
1408
1409 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001410}
1411
1412bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1413 const TensorInfo& output,
1414 const Pooling2dDescriptor& descriptor,
1415 Optional<std::string&> reasonIfUnsupported) const
1416{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001417 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001418 bool supported = true;
1419
1420 // Define supported output and inputs types.
Keith Davis67e6c542020-02-19 10:08:33 +00001421 std::array<DataType,5> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001422 {
1423 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001424 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001425 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001426 DataType::QAsymmU8,
1427 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001428 };
1429
1430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431                                  "Reference pooling2d: input is not a supported type.");
1432
1433 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1434                                  "Reference pooling2d: output is not a supported type.");
1435
1436 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1437                                  "Reference pooling2d: input and output types are mismatched.");
1438
1439 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001440}
1441
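// Illustrative usage sketch (comment only; the shapes and quantization scale below are
// placeholders): querying Float32 -> QAsymmU8 quantize support on the reference backend.
//
//   RefLayerSupport layerSupport;
//   TensorInfo inputInfo({1, 224, 224, 3}, DataType::Float32);
//   TensorInfo outputInfo({1, 224, 224, 3}, DataType::QAsymmU8, 1.0f / 255.0f, 0);
//   std::string reason;
//   bool isSupported = layerSupport.IsQuantizeSupported(inputInfo, outputInfo,
//                                                       Optional<std::string&>(reason));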
Derek Lamberti5f400d62019-03-25 15:41:58 +00001442bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1443 const TensorInfo& output,
1444 Optional<std::string&> reasonIfUnsupported) const
1445{
1446 bool supported = true;
1447
Finn Williamsfd271062019-12-04 14:27:27 +00001448 // Define supported input types.
Ryan OShea9add1202020-02-07 10:06:33 +00001449 std::array<DataType,6> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00001450 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001451 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001452 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001453 DataType::QAsymmU8,
1454 DataType::QSymmS8,
1455 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001456 };
1457
1458 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1459 "Reference quantize: input type not supported.");
1460
1461 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001462 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001463 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001464 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001465 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001466 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001467 };
1468 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1469 "Reference quantize: output type not supported.");
1470
1471 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1472 "Reference quantize: input and output shapes have different num total elements.");
1473
1474 return supported;
1475}
1476
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001477bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001478 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001479 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001480 Optional<std::string&> reasonIfUnsupported) const
1481{
Kevin Maya023c402019-12-12 17:28:05 +00001482 ignore_unused(output);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001483 ignore_unused(descriptor);
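    // Reshape only reinterprets the tensor shape, so support is decided by the input data type
    // alone; the output info is intentionally not validated here.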
Nina Drozd2f2778f2019-05-27 10:37:05 +01001484 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001485 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001486 {
1487 DataType::Float32,
1488 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001489 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001490 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001491 DataType::QAsymmU8,
1492 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001493 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001494
Nina Drozd2f2778f2019-05-27 10:37:05 +01001495 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1496 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001497}
1498
1499bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001500 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001501 Optional<std::string&> reasonIfUnsupported) const
1502{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001503 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001504 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001505 {
1506 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001507 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001508 DataType::QAsymmU8,
1509 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001510 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001511
1512 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1513 "Reference ResizeBilinear: input type not supported");
1514
1515 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1516 "Reference ResizeBilinear: output type not supported");
1517
1518 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1519                                  "Reference ResizeBilinear: input and output types are mismatched");
1520
1521 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001522}
1523
Teresa Charlin970f43b2019-07-01 13:51:07 +01001524bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1525 const TensorInfo& output,
1526 const ResizeDescriptor& descriptor,
1527 Optional<std::string&> reasonIfUnsupported) const
1528{
Derek Lamberti901ea112019-12-10 22:07:09 +00001529 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001530 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +00001531 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001532 {
1533 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001534 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001535 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001536 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001537 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001538 };
1539
1540 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1541 "Reference Resize: input type not supported");
1542
1543 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1544 "Reference Resize: output type not supported");
1545
1546 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1547                                  "Reference Resize: input and output types are mismatched");
1548
1549 return supported;
1550}
1551
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001552bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1553 const TensorInfo& output,
1554 Optional<std::string&> reasonIfUnsupported) const
1555{
josh minor4a3c6102020-01-06 16:40:46 -06001556 return IsElementwiseUnarySupported(input,
1557 output,
1558 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1559 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001560}
1561
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001562bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1563 const TensorInfo& output,
1564 const SliceDescriptor& descriptor,
1565 Optional<std::string&> reasonIfUnsupported) const
1566{
Derek Lamberti901ea112019-12-10 22:07:09 +00001567 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001568 bool supported = true;
1569
1570 std::array<DataType, 3> supportedTypes =
1571 {
1572 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001573 DataType::QAsymmU8,
1574 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001575 };
1576
1577 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1578 "Reference Slice: input type not supported");
1579
1580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1581 "Reference Slice: output type not supported");
1582
1583 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1584 "Reference Slice: input and output types are mismatched");
1585
1586 return supported;
1587}
1588
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001589bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1590 const TensorInfo& output,
1591 const SoftmaxDescriptor& descriptor,
1592 Optional<std::string&> reasonIfUnsupported) const
1593{
Derek Lamberti901ea112019-12-10 22:07:09 +00001594 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001595 bool supported = true;
Keith Davis0c2eeac2020-02-11 16:51:50 +00001596 std::array<DataType,6> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001597 {
1598 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001599 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001600 DataType::QSymmS8,
1601 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001602 DataType::QAsymmU8,
1603 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001604 };
1605
1606 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001607                                  "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001608
1609 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001610                                  "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001611
1612 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001613                                  "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001614
1615 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001616}
1617
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001618bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1619 const TensorInfo& output,
1620 const SpaceToBatchNdDescriptor& descriptor,
1621 Optional<std::string&> reasonIfUnsupported) const
1622{
Derek Lamberti901ea112019-12-10 22:07:09 +00001623 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001624 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001625 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001626 {
1627 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001628 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001629 DataType::QAsymmU8,
1630 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001631 };
1632
1633 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1634 "Reference SpaceToBatchNd: input type not supported");
1635
1636 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1637 "Reference SpaceToBatchNd: output type not supported");
1638
1639 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1640 "Reference SpaceToBatchNd: input and output types are mismatched");
1641
1642 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001643}
1644
Keith Davisa57eccb2019-06-14 17:33:22 +01001645bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001646 const TensorInfo& output,
1647 const SpaceToDepthDescriptor& descriptor,
1648 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001649{
1650
1651 ignore_unused(descriptor);
1652 bool supported = true;
1653
Matthew Jackson9bff1442019-09-12 09:08:23 +01001654 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001655 {
1656 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001657 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001658 DataType::QAsymmU8,
1659 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001660 };
1661
1662 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1663 "Reference SpaceToDepth: input type not supported");
1664
1665 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1666 "Reference SpaceToDepth: output type not supported");
1667
1668 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1669 "Reference SpaceToDepth: input and output types are mismatched");
1670
1671 return supported;
1672}
1673
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001674bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1675 const ViewsDescriptor& descriptor,
1676 Optional<std::string&> reasonIfUnsupported) const
1677{
1678 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001679 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001680 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001681 {
1682 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001683 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001684 DataType::QAsymmU8,
1685 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001686 };
1687
1688 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1689 "Reference splitter: input type not supported");
1690
1691 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001692}
1693
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001694bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1695 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1696 const ViewsDescriptor& descriptor,
1697 Optional<std::string&> reasonIfUnsupported) const
1698{
1699 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001700 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001701 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001702 {
1703 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001704 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001705 DataType::QAsymmU8,
1706 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001707 };
1708
1709 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1710                                   "Reference splitter: input type not supported");
1711    for (const TensorInfo& output : outputs)
1712 {
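        // Each output view must itself be of a supported type and must match the input's data type.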
1713        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1714                                       "Reference splitter: output type not supported");
1715
1716 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1717 "Reference splitter: input and output types mismatched.");
1718 }
1719
1720 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001721}
1722
Matthew Jackson81e601c2019-07-11 12:07:09 +01001723bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1724 const TensorInfo& output,
1725 const StackDescriptor& descriptor,
1726 Optional<std::string&> reasonIfUnsupported) const
1727{
1728 ignore_unused(descriptor);
1729
1730 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001731 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001732 {
1733 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001734 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001735 DataType::QAsymmU8,
1736 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001737 };
1738
1739 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1740 "Reference stack: output type not supported");
1741 for (const TensorInfo* input : inputs)
1742 {
1743 BOOST_ASSERT(input != nullptr);
1744 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1745 "Reference stack: input type not supported");
1746
1747 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1748 "Reference stack: input and output types mismatched.");
1749 }
1750
1751 return supported;
1752}
1753
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001754bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1755 const TensorInfo& output,
1756 const StridedSliceDescriptor& descriptor,
1757 Optional<std::string&> reasonIfUnsupported) const
1758{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001759 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001760 bool supported = true;
1761
1762 std::array<DataType,3> supportedTypes =
1763 {
1764 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001765 DataType::QAsymmU8,
1766 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001767 };
1768
1769 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1770 "Reference StridedSlice: input type not supported");
1771
1772 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1773 "Reference StridedSlice: output type not supported");
1774
1775 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1776 "Reference StridedSlice: input and output types are mismatched");
1777
1778 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001779}
1780
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001781bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1782 const TensorInfo& input1,
1783 const TensorInfo& output,
1784 Optional<std::string&> reasonIfUnsupported) const
1785{
Sadik Armagan2999a022019-04-09 14:20:12 +01001786 bool supported = true;
1787
Matthew Jackson9bff1442019-09-12 09:08:23 +01001788 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001789 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001790 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001791 DataType::QAsymmU8,
1792 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001793 };
1794
1795 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1796 "Reference subtraction: input 0 is not a supported type.");
1797
1798 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1799 "Reference subtraction: input 1 is not a supported type.");
1800
1801 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1802 "Reference subtraction: output is not a supported type.");
1803
1804 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1805                                  "Reference subtraction: input 0 and input 1 types are mismatched");
1806
1807 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1808 "Reference subtraction: input and output types are mismatched");
1809
1810 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1811 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1812
1813 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001814}
1815
Matteo Martincighab9e5252019-06-13 17:27:46 +01001816bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1817 const TensorInfo& alpha,
1818 const TensorInfo& output,
1819 Optional<std::string&> reasonIfUnsupported) const
1820{
1821 bool supported = true;
1822
Matthew Jackson9bff1442019-09-12 09:08:23 +01001823 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001824 {
1825 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001826 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001827 DataType::QAsymmU8,
1828 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001829 };
1830
1831 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1832 "PReLU: input is not a supported type.");
1833
1834 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1835 "PReLU: alpha is not a supported type.");
1836
1837 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1838 "PReLU: output is not a supported type.");
1839
1840 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1841 "PReLU: input, alpha and output types are mismatched");
1842
1843 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1844 "PReLU: shapes are not suitable for implicit broadcast");
1845
1846 return supported;
1847}
1848
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001849bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1850 const TensorInfo& output,
1851 const TransposeConvolution2dDescriptor& descriptor,
1852 const TensorInfo& weights,
1853 const Optional<TensorInfo>& biases,
1854 Optional<std::string&> reasonIfUnsupported) const
1855{
Derek Lamberti901ea112019-12-10 22:07:09 +00001856 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001857 bool supported = true;
1858
Matthew Jackson252df3a2019-09-11 09:19:18 +01001859 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001860 {
1861 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001862 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001863 DataType::QAsymmU8,
1864 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001865 };
1866
1867 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1868 "Reference TransposeConvolution2d: input is not a supported type.");
1869
1870 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1871 "Reference TransposeConvolution2d: output is not a supported type.");
1872
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001873 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1874 "Reference TransposeConvolution2d: input and output types mismatched.");
1875
1877 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001878 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001879 {
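        // For quantized (QAsymmU8) inputs the weights may instead use symmetric 8-bit quantization,
        // including the deprecated per-axis variant, so widen the accepted weight types here.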
Derek Lambertid466a542020-01-22 15:37:29 +00001880 ARMNN_NO_DEPRECATE_WARN_BEGIN
1881 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001882 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001883 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001884 DataType::QSymmS8,
1885 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001886 };
Derek Lambertid466a542020-01-22 15:37:29 +00001887 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001888
1889 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1890 "Reference TransposeConvolution2d: weights type not supported for "
1891 "quantized input.");
1892 }
1893 else
1894 {
1895 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1896 "Reference TransposeConvolution2d: weights is not a supported type.");
1897
1898 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1899 "Reference TransposeConvolution2d: input and weights types mismatched.");
1900 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001901
1902 if (biases.has_value())
1903 {
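        // Biases are held at higher precision than the weights: Signed32 for quantized networks,
        // otherwise one of the float types.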
Matthew Jackson252df3a2019-09-11 09:19:18 +01001904 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001905 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001906 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001907 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001908 DataType::Signed32
1909 };
1910 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1911 "Reference TransposeConvolution2d: biases is not a supported type.");
1912 }
1913
1914 return supported;
1915}
1916
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001917bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
1918 const TensorInfo& output,
1919 const TransposeDescriptor& descriptor,
1920 Optional<std::string&> reasonIfUnsupported) const
1921{
1922 ignore_unused(descriptor);
1923 bool supported = true;
1924
1925 // Define supported output and inputs types.
1926 std::array<DataType, 4> supportedTypes =
1927 {
1928 DataType::Float32,
1929 DataType::Float16,
1930 DataType::QAsymmU8,
1931 DataType::QSymmS16
1932 };
1933
1934 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1935 "Reference transpose: input is not a supported type.");
1936
1937 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1938 "Reference transpose: output is not a supported type.");
1939
1940 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1941 "Reference transpose: input and output types are mismatched.");
1942
1943 return supported;
1944}
1945
arovir011c7c81b2018-10-08 11:34:28 +01001946} // namespace armnn