blob: 5cb36c42991c6ea9ce98cc21b11296975be47355 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matteo Martincighe011d202019-11-28 11:35:47 +000012#include <LayerSupportCommon.hpp>
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000015
Matteo Martincighe011d202019-11-28 11:35:47 +000016#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000020#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
30template<typename Float32Func, typename Uint8Func, typename ... Params>
31bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
32 DataType dataType,
33 Float32Func floatFuncPtr,
34 Uint8Func uint8FuncPtr,
35 Params&&... params)
36{
37 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
38 dataType,
39 &FalseFunc<Params...>,
40 floatFuncPtr,
41 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000042 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000043 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 std::forward<Params>(params)...);
45}
46
47} // anonymous namespace
48
namespace
{

/// Builds the diagnostic reported when a tensor does not have the rank a reference layer requires.
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Layer label embedded in the message (e.g. "batchToSpaceNd").
/// @param tensorName Label of the offending tensor (e.g. "input" or "output").
/// @return "Reference <layer>: Expected <expected> dimensions but got <actual> dimensions instead,
///         for the '<tensor>' tensor."
/// The strings are only read, so they are taken by const reference; this also lets callers pass
/// temporaries and string literals.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got " +
           std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000064
Sadik Armagan9199e582019-09-05 17:35:31 +010065bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
66 Optional<std::string&> reasonIfUnsupported) const
67{
josh minor4a3c6102020-01-06 16:40:46 -060068 return IsElementwiseUnarySupported(input,
69 output,
70 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
71 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010072}
73
arovir011c7c81b2018-10-08 11:34:28 +010074bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
75 const TensorInfo& output,
76 const ActivationDescriptor& descriptor,
77 Optional<std::string&> reasonIfUnsupported) const
78{
Derek Lamberti50db4e82019-03-13 14:16:15 +000079 bool supported = true;
80
81 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000082 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +000083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010084 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000085 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000086 DataType::QAsymmU8,
87 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000088 };
89
90 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
91 "Reference activation: input type not supported.");
92
93 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
94 "Reference activation: output type not supported.");
95
96 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
97 "Reference activation: input and output types mismatched.");
98
99 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
100 "Reference activation: input and output shapes are of different rank.");
101
102
103 struct ActivationFunctionSupported : public Rule
104 {
105 ActivationFunctionSupported(const ActivationDescriptor& desc)
106 {
107 switch(desc.m_Function)
108 {
109 case ActivationFunction::Abs:
110 case ActivationFunction::BoundedReLu:
111 case ActivationFunction::LeakyReLu:
112 case ActivationFunction::Linear:
113 case ActivationFunction::ReLu:
114 case ActivationFunction::Sigmoid:
115 case ActivationFunction::SoftReLu:
116 case ActivationFunction::Sqrt:
117 case ActivationFunction::Square:
118 case ActivationFunction::TanH:
119 {
120 m_Res = true;
121 break;
122 }
123 default:
124 {
125 m_Res = false;
126 break;
127 }
128 }
129 }
130 };
131
132 // Function is supported
133 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
134 "Reference activation: function not supported.");
135
136 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100137}
138
139bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
140 const TensorInfo& input1,
141 const TensorInfo& output,
142 Optional<std::string&> reasonIfUnsupported) const
143{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000144 bool supported = true;
145
Keith Davis0c2eeac2020-02-11 16:51:50 +0000146 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000147 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100148 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000149 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000150 DataType::QAsymmU8,
151 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000152 };
153
154 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
155 "Reference addition: input 0 is not a supported type.");
156
157 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
158 "Reference addition: input 1 is not a supported type.");
159
160 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
161 "Reference addition: output is not a supported type.");
162
163 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
164 "Reference addition: input 0 and Input 1 types are mismatched");
165
166 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
167 "Reference addition: input and output types are mismatched");
168
169 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
170 "Reference addition: shapes are not suitable for implicit broadcast.");
171
172 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100173}
174
Nikhil Raj68c2c902019-09-19 11:21:11 +0100175bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
176 const armnn::ArgMinMaxDescriptor &descriptor,
177 armnn::Optional<std::string &> reasonIfUnsupported) const
178{
179 ignore_unused(descriptor);
180
Francis Murtagh1939df52019-11-13 15:21:09 +0000181 std::array<DataType, 4> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100182 {
183 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000184 DataType::QAsymmU8,
185 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000186 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100187 };
188
189 bool supported = true;
190
191 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
192 "Reference ArgMinMax: input is not a supported type.");
193 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
194 "Reference ArgMinMax: output type not supported");
195
196 return supported;
197}
198
arovir011c7c81b2018-10-08 11:34:28 +0100199bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
200 const TensorInfo& output,
201 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100202 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100203 const TensorInfo& beta,
204 const TensorInfo& gamma,
205 const BatchNormalizationDescriptor& descriptor,
206 Optional<std::string&> reasonIfUnsupported) const
207{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100208 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100209
Matthew Jackson9bff1442019-09-12 09:08:23 +0100210 std::array<DataType, 4> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100211 {
212 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100213 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000214 DataType::QAsymmU8,
215 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100216 };
217
218 bool supported = true;
219
220 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
221 "Reference batch normalization: input is not a supported type.");
222
223 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
224 "Reference batch normalization: output is not a supported type.");
225
226 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
227 "Reference batch normalization: input and output types are mismatched");
228
229 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
230 "Reference batch normalization: mean is not a supported type.");
231
232 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
233 "Reference batch normalization: variance is not a supported type.");
234
235 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
236 "Reference batch normalization: beta is not a supported type.");
237
238 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
239 "Reference batch normalization: gamma is not a supported type.");
240
241 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100242}
243
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000244bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
245 const TensorInfo& output,
246 const BatchToSpaceNdDescriptor& descriptor,
247 Optional<std::string&> reasonIfUnsupported) const
248{
249 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100250
251 bool supported = true;
252
253 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
254 std::string inputTensorStr = "input";
255 std::string outputTensorStr = "output";
256
257 // Define supported types.
Matthew Jackson9bff1442019-09-12 09:08:23 +0100258 std::array<DataType,4> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100259 {
260 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100261 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000262 DataType::QAsymmU8,
263 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100264 };
265
266 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
267 "Reference BatchToSpaceNd: input type not supported.");
268
269 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
270 "Reference BatchToSpaceNd: output type not supported.");
271
272 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
273 "Reference BatchToSpaceNd: input and output types mismatched.");
274
275 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
276 reasonIfUnsupported,
277 CreateIncorrectDimensionsErrorMsg(4,
278 output.GetNumDimensions(),
279 batchToSpaceNdLayerStr,
280 outputTensorStr).data());
281
282 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
283 reasonIfUnsupported,
284 CreateIncorrectDimensionsErrorMsg(4,
285 input.GetNumDimensions(),
286 batchToSpaceNdLayerStr,
287 inputTensorStr).data());
288
289 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000290}
291
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100292bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
293 const TensorInfo& input1,
294 const TensorInfo& output,
295 const ComparisonDescriptor& descriptor,
296 Optional<std::string&> reasonIfUnsupported) const
297{
298 boost::ignore_unused(descriptor);
299
300 std::array<DataType, 4> supportedInputTypes =
301 {
302 DataType::Float32,
303 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000304 DataType::QAsymmU8,
305 DataType::QSymmS16
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100306 };
307
308 bool supported = true;
309 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
310 "Reference comparison: input 0 is not a supported type");
311
312 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
313 "Reference comparison: input 0 and Input 1 types are mismatched");
314
315 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
316 "Reference comparison: output is not of type Boolean");
317
318 return supported;
319}
320
Jim Flynn906f9462019-05-10 13:55:21 +0100321bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
322 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100323 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100324 Optional<std::string&> reasonIfUnsupported) const
325{
Jim Flynne242f2d2019-05-22 14:24:13 +0100326 ignore_unused(descriptor);
327
328 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +0000329 std::array<DataType,5> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100330 {
331 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100332 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000333 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +0000334 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000335 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100336 };
337
338 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
339 "Reference concatenation: output type not supported");
340 for (const TensorInfo* input : inputs)
341 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100342 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100343 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
344 "Reference concatenation: input type not supported");
345
346 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
347 "Reference concatenation: input and output types mismatched.");
348 }
349
350 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100351}
352
arovir011c7c81b2018-10-08 11:34:28 +0100353bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
354 Optional<std::string&> reasonIfUnsupported) const
355{
Keith Davis67e6c542020-02-19 10:08:33 +0000356 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100357 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100358 DataType::Float32,
359 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000360 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +0000361 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000362 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000363 DataType::QSymmS16
Nina Drozd58ef2c62019-05-16 12:09:18 +0100364 };
365
366 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
367 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100368}
369
370bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
371 const TensorInfo& output,
372 Optional<std::string&> reasonIfUnsupported) const
373{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100374 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
375 input.GetDataType(),
376 &TrueFunc<>,
377 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000378 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000379 &FalseFuncI32<>,
380 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100381 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
382 output.GetDataType(),
383 &FalseOutputFuncF16<>,
384 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000385 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000386 &FalseFuncI32<>,
387 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100388}
389
390bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
391 const TensorInfo& output,
392 Optional<std::string&> reasonIfUnsupported) const
393{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100394 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
395 input.GetDataType(),
396 &FalseInputFuncF16<>,
397 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000398 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000399 &FalseFuncI32<>,
400 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100401 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
402 output.GetDataType(),
403 &TrueFunc<>,
404 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000405 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000406 &FalseFuncI32<>,
407 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100408}
409
410bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
411 const TensorInfo& output,
412 const Convolution2dDescriptor& descriptor,
413 const TensorInfo& weights,
414 const Optional<TensorInfo>& biases,
415 Optional<std::string&> reasonIfUnsupported) const
416{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100417 bool supported = true;
418
419 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000420 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000421 {
422 DataType::Float32,
423 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000424 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000425 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000426 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000427 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100428 };
429
430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000431 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100432
433 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000434 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100435
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100436 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000437 "Reference Convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100438
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000439 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000440 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000441 {
Derek Lambertid466a542020-01-22 15:37:29 +0000442 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000443 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000444 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000445 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000446 DataType::QSymmS8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000447 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000448 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000449 };
Derek Lambertid466a542020-01-22 15:37:29 +0000450 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000451
452 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000453 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000454 }
455 else
456 {
457 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000458 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000459
460 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000461 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000462 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100463
464 if (biases.has_value())
465 {
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000466 std::array<DataType,3> biasesSupportedTypes =
467 {
468 DataType::Float32,
469 DataType::Float16,
470 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100471 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000472
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100473 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000474 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100475 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100476 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100477
478 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100479}
480
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000481bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
482 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000483 Optional<std::string&> reasonIfUnsupported) const
484{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100485 bool supported = true;
486
Keith Davis0c2eeac2020-02-11 16:51:50 +0000487 std::array<DataType, 7> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100488 {
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000489 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100490 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000491 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000492 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +0000493 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000494 DataType::QSymmS16,
495 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100496 };
497
498 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000499 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100500
501 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000502 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100503
504 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000505 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100506
507 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000508}
509
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100510bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
511 const TensorInfo& output,
512 const DepthToSpaceDescriptor& descriptor,
513 Optional<std::string&> reasonIfUnsupported) const
514{
515 ignore_unused(descriptor);
516 bool supported = true;
517
518 std::array<DataType,4> supportedTypes =
519 {
520 DataType::Float32,
521 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000522 DataType::QAsymmU8,
523 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100524 };
525
526 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
527 "Reference DepthToSpace: input type not supported");
528
529 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
530 "Reference DepthToSpace: output type not supported");
531
532 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
533 "Reference DepthToSpace: input and output types are mismatched");
534
535 return supported;
536}
537
arovir011c7c81b2018-10-08 11:34:28 +0100538bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
539 const TensorInfo& output,
540 const DepthwiseConvolution2dDescriptor& descriptor,
541 const TensorInfo& weights,
542 const Optional<TensorInfo>& biases,
543 Optional<std::string&> reasonIfUnsupported) const
544{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100545 bool supported = true;
546
547 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000548 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100549 {
550 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100551 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000552 DataType::QSymmS8,
553 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000554 DataType::QAsymmU8,
555 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100556 };
557
558 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
559 "Reference DepthwiseConvolution2d: input is not a supported type.");
560
561 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
562 "Reference DepthwiseConvolution2d: output is not a supported type.");
563
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100564 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
565 "Reference DepthwiseConvolution2d: input and output types mismatched.");
566
Derek Lambertid466a542020-01-22 15:37:29 +0000567 ARMNN_NO_DEPRECATE_WARN_BEGIN
568 std::array<DataType, 3> supportedWeightTypes =
569 {
570 DataType::QAsymmU8,
571 DataType::QSymmS8,
572 DataType::QuantizedSymm8PerAxis // deprecated
573 };
574 ARMNN_NO_DEPRECATE_WARN_END
575
Teresa Charlind8df0262019-11-11 12:28:15 +0000576 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000577 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000578 {
Teresa Charlind8df0262019-11-11 12:28:15 +0000579
580 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
581 "Reference convolution2d: weights type not supported for quantized input.");
582 }
583 else
584 {
585 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
586 "Reference DepthwiseConvolution2d: weights is not a supported type.");
587
588 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
589 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
590 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100591
592 if (biases.has_value())
593 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100594 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100595 {
596 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100597 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100598 DataType::Signed32
599 };
600 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
601 "Reference DepthwiseConvolution2d: biases is not a supported type.");
602 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100603 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100604
605 return supported;
606
arovir011c7c81b2018-10-08 11:34:28 +0100607}
608
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000609bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
610 const TensorInfo& output,
611 Optional<std::string&> reasonIfUnsupported) const
612{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100613 bool supported = true;
614
Ryan OShea9add1202020-02-07 10:06:33 +0000615 std::array<DataType,4> supportedInputTypes = {
616 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000617 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000618 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000619 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100620 };
621
622 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000623 "Reference for Dequantize layer: input type not supported.");
624
625 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
626 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100627
Derek Lambertid466a542020-01-22 15:37:29 +0000628 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
629 "Reference dequantize: per-axis quantized input not support .");
630
Jan Eilersf7107932019-11-01 11:09:36 +0000631 std::array<DataType,2> supportedOutputTypes = {
632 DataType::Float32,
633 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100634 };
635
636 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000637 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100638
639 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000640 "Reference for Dequantize layer: input/output shapes have different num total "
641 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100642
643 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000644}
645
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000646bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
647 const TensorInfo& scores,
648 const TensorInfo& anchors,
649 const TensorInfo& detectionBoxes,
650 const TensorInfo& detectionClasses,
651 const TensorInfo& detectionScores,
652 const TensorInfo& numDetections,
653 const DetectionPostProcessDescriptor& descriptor,
654 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000655{
Derek Lamberti901ea112019-12-10 22:07:09 +0000656 boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
657
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100658 bool supported = true;
659
Mike Kelly4992c342019-08-14 11:33:11 +0100660 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100661 {
662 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000663 DataType::QAsymmU8,
664 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100665 };
666
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000667 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100668 "Reference DetectionPostProcess: input 0 is not a supported type.");
669
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000670 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100671 "Reference DetectionPostProcess: input 1 is not a supported type.");
672
673 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000674}
675
Pablo Tellof0bd6832019-04-26 17:58:13 +0100676bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
677 const TensorInfo& output,
678 const DepthwiseConvolution2dDescriptor& descriptor,
679 const TensorInfo& weights,
680 const Optional<TensorInfo>& biases,
681 Optional<std::string&> reasonIfUnsupported) const
682{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100683 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100684}
685
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100686bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100687 const TensorInfo& input1,
688 const TensorInfo& output,
689 Optional<std::string&> reasonIfUnsupported) const
690{
Sadik Armagan2999a022019-04-09 14:20:12 +0100691 bool supported = true;
692
Matthew Jackson9bff1442019-09-12 09:08:23 +0100693 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100694 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100695 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000696 DataType::QAsymmU8,
697 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100698 };
699
700 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
701 "Reference division: input 0 is not a supported type.");
702
703 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
704 "Reference division: input 1 is not a supported type.");
705
706 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
707 "Reference division: output is not a supported type.");
708
709 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
710 "Reference division: input 0 and Input 1 types are mismatched");
711
712 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
713 "Reference division: input and output types are mismatched");
714
715 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
716 "Reference division: shapes are not suitable for implicit broadcast.");
717
718 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100719}
720
josh minor4a3c6102020-01-06 16:40:46 -0600721bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
722 const TensorInfo& output,
723 const ElementwiseUnaryDescriptor& descriptor,
724 Optional<std::string&> reasonIfUnsupported) const
725{
726 boost::ignore_unused(descriptor);
727
728 std::array<DataType, 4> supportedTypes =
729 {
730 DataType::Float32,
731 DataType::Float16,
732 DataType::QAsymmU8,
733 DataType::QSymmS16
734 };
735
736 bool supported = true;
737
738 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
739 "Reference elementwise unary: input type not supported");
740
741 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
742 "Reference elementwise unary: output type not supported");
743
744 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
745 "Reference elementwise unary: input and output types not matching");
746
747 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
748 "Reference elementwise unary: input and output shapes"
749 "have different number of total elements");
750
751 return supported;
752}
753
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000754bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
755 const TensorInfo& input1,
756 const TensorInfo& output,
757 Optional<std::string&> reasonIfUnsupported) const
758{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100759 return IsComparisonSupported(input0,
760 input1,
761 output,
762 ComparisonDescriptor(ComparisonOperation::Equal),
763 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000764}
765
arovir011c7c81b2018-10-08 11:34:28 +0100766bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
767 const FakeQuantizationDescriptor& descriptor,
768 Optional<std::string&> reasonIfUnsupported) const
769{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100770 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100771 bool supported = true;
772
773 std::array<DataType,1> supportedTypes =
774 {
775 DataType::Float32
776 };
777
778 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
779 "Reference fake quantization: input type not supported.");
780
781 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100782}
783
784bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
785 const TensorInfo& output,
786 Optional<std::string&> reasonIfUnsupported) const
787{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100788 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100789 bool supported = true;
790
Matthew Jackson9bff1442019-09-12 09:08:23 +0100791 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100792 {
James Conroyb40d7102019-06-04 12:32:09 +0100793 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100794 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000795 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100796 };
797
798 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
799 "Reference Floor: input type not supported.");
800
801 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
802 "Reference Floor: output type not supported.");
803
804 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100805}
806
807bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
808 const TensorInfo& output,
809 const TensorInfo& weights,
810 const TensorInfo& biases,
811 const FullyConnectedDescriptor& descriptor,
812 Optional<std::string&> reasonIfUnsupported) const
813{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100814 bool supported = true;
815
816 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100817 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100818 {
819 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100820 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000821 DataType::QAsymmU8,
822 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100823 };
824
825 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
826 "Reference Fully Connected: input type not supported.");
827
828 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
829 "Reference Fully Connected: output type not supported.");
830
831 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
832 "Reference Fully Connected: input and output types mismatched.");
833
834 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
835 "Reference Fully Connected: weights type not supported.");
836
837 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
838 "Reference Fully Connected: input and weight types mismatched.");
839
840 if (descriptor.m_BiasEnabled)
841 {
842 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100843 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100844 supportedBiasTypes =
845 {
846 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100847 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100848 DataType::Signed32
849 };
850
851 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
852 "Reference Fully Connected: bias type not supported.");
853
854 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
855 "Reference Fully Connected: bias and weight types mismatch.");
856
857 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
858 "Reference Fully Connected: bias type inferred from weights is incompatible.");
859
860 }
861
862 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100863}
864
narpra014951d842019-01-18 16:53:53 +0000865bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
866 const armnn::TensorInfo& input1,
867 const armnn::TensorInfo& output,
868 armnn::Optional<std::string&> reasonIfUnsupported) const
869{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100870 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100871 std::array<DataType,4> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100872 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100873 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100874 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000875 DataType::QAsymmU8,
876 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100877 };
878
879 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
880 "Reference Gather: input type not supported");
881
882 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
883 "Reference Gather: output type not supported");
884
885 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
886 "Reference Gather: indices (input1) type not supported");
887
888 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
889 "Reference Gather: input and output types not matching");
890
891 return supported;
narpra014951d842019-01-18 16:53:53 +0000892}
893
FrancisMurtagh878f0232018-12-19 10:56:15 +0000894bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
895 const TensorInfo& input1,
896 const TensorInfo& output,
897 Optional<std::string&> reasonIfUnsupported) const
898{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100899 return IsComparisonSupported(input0,
900 input1,
901 output,
902 ComparisonDescriptor(ComparisonOperation::Greater),
903 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +0000904}
905
Derek Lamberti901ea112019-12-10 22:07:09 +0000906bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
907 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +0100908{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100909 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100910}
911
Kevin May09ca49c2019-10-09 12:37:34 +0100912bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
913 const TensorInfo& output,
914 const InstanceNormalizationDescriptor& descriptor,
915 Optional<std::string&> reasonIfUnsupported) const
916{
917 ignore_unused(descriptor);
918 // Define supported types
919 std::array<DataType, 4> supportedTypes =
920 {
921 DataType::Float32,
922 DataType::Float16
923 };
924
925 bool supported = true;
926
927 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
928 "Reference Instance Normalization: input type not supported.");
929
930 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
931 "Reference Instance Normalization: output type not supported.");
932
933 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
934 "Reference Instance Normalization: input and output types mismatched.");
935
936 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
937 "Reference Instance Normalization: input and output shapes have different "
938 "num total elements.");
939
940 return supported;
941}
942
arovir011c7c81b2018-10-08 11:34:28 +0100943bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
944 const TensorInfo& output,
945 const L2NormalizationDescriptor& descriptor,
946 Optional<std::string&> reasonIfUnsupported) const
947{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100948 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100949 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100950 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100951 {
952 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100953 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000954 DataType::QAsymmU8,
955 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100956 };
957
958 bool supported = true;
959
960 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
961 "Reference L2normalization: input type not supported.");
962
963 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
964 "Reference L2normalization: output type not supported.");
965
966 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
967 "Reference L2normalization: input and output types mismatched.");
968
969 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
970 "Reference L2normalization: input and output shapes have different "
971 "num total elements.");
972
973 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100974}
975
Aron Virginas-Tare662a942019-10-14 15:12:00 +0100976bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
977 const TensorInfo& output,
978 const LogSoftmaxDescriptor& descriptor,
979 Optional<std::string&> reasonIfUnsupported) const
980{
981 ignore_unused(descriptor);
982
983 std::array<DataType, 2> supportedTypes =
984 {
985 DataType::Float32,
986 DataType::Float16
987 };
988
989 bool supported = true;
990 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
991 "Reference LogSoftmax: input type not supported");
992
993 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
994 "Reference LogSoftmax: output type not supported");
995
996 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
997 "Reference LogSoftmax: input and output types do not match");
998
999 return supported;
1000}
1001
arovir011c7c81b2018-10-08 11:34:28 +01001002bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1003 const TensorInfo& outputStateIn,
1004 const TensorInfo& cellStateIn,
1005 const TensorInfo& scratchBuffer,
1006 const TensorInfo& outputStateOut,
1007 const TensorInfo& cellStateOut,
1008 const TensorInfo& output,
1009 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001010 const LstmInputParamsInfo& paramsInfo,
1011 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001012{
telsoa01c577f2c2018-08-31 09:22:23 +01001013 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +01001014 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001015
1016 bool supported = true;
1017
1018 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001019 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001020 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001021 };
1022
Jan Eilersd01a83c2019-07-03 18:20:40 +01001023 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001024 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1025 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001026 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1027 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001028 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1029 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001030 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1031 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001032 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1033 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001034 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1035 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001036 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1037 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001038 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001039 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001040 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001041 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001042 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001043 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001044 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001045 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001046 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001047 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001048 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001049 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001050 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001051 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001052 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001053 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001054 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001055 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001056 "Reference Lstm: input and OutputGateBias types are mismatched");
1057 if (!descriptor.m_CifgEnabled)
1058 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001059 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001060 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001061 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001062 reasonIfUnsupported,
1063 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001064 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001065 "Reference Lstm: input and InputGateBias types are mismatched");
1066 if (descriptor.m_PeepholeEnabled)
1067 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001068 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001069 reasonIfUnsupported,
1070 "Reference Lstm: input and CellToInputWeights types are mismatched");
1071 }
1072 }
1073 if (descriptor.m_PeepholeEnabled)
1074 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001075 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001076 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001077 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001078 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1079 }
1080 if (descriptor.m_ProjectionEnabled)
1081 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001082 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001083 "Reference Lstm: input and mProjectionWeights types are mismatched");
1084 if (paramsInfo.m_ProjectionBias != nullptr)
1085 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001086 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001087 "Reference Lstm: input and ProjectionBias types are mismatched");
1088 }
1089 }
1090 if (descriptor.m_LayerNormEnabled)
1091 {
1092 if (!descriptor.m_CifgEnabled)
1093 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001094 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001095 reasonIfUnsupported,
1096 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1097 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001098 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001099 reasonIfUnsupported,
1100 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001101 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001102 reasonIfUnsupported,
1103 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001104 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001105 reasonIfUnsupported,
1106 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1107 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001108
1109 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001110}
1111
saoste012df12b32018-11-28 16:57:20 +00001112bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1113 const TensorInfo& input1,
1114 const TensorInfo& output,
1115 Optional<std::string&> reasonIfUnsupported) const
1116{
Sadik Armagan2999a022019-04-09 14:20:12 +01001117 bool supported = true;
1118
Keith Davis5204aa82020-01-27 15:24:59 +00001119 std::array<DataType,5> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001120 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001121 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001122 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001123 DataType::QAsymmU8,
1124 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001125 };
1126
1127 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1128 "Reference maximum: input 0 is not a supported type.");
1129
1130 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1131 "Reference maximum: input 1 is not a supported type.");
1132
1133 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1134 "Reference maximum: output is not a supported type.");
1135
1136 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1137 "Reference maximum: input 0 and Input 1 types are mismatched");
1138
1139 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1140 "Reference maximum: input and output types are mismatched");
1141
1142 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1143 "Reference maximum: shapes are not suitable for implicit broadcast.");
1144
1145 return supported;
saoste012df12b32018-11-28 16:57:20 +00001146}
1147
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001148bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1149 const TensorInfo& output,
1150 const MeanDescriptor& descriptor,
1151 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001152{
James Conroy4d1ff582019-06-10 17:06:39 +01001153 bool supported = true;
1154 std::string meanLayerStr = "Mean";
1155 std::string outputTensorStr = "output";
1156
Matthew Jackson252df3a2019-09-11 09:19:18 +01001157 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001158 {
1159 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001160 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001161 DataType::QAsymmU8,
1162 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001163 };
1164
1165 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1166 "Reference Mean: input type not supported.");
1167
1168 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1169 "Reference Mean: input and output types are mismatched");
1170
1171 if (descriptor.m_KeepDims)
1172 {
1173 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1174 reasonIfUnsupported,
1175 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1176 output.GetNumDimensions(),
1177 meanLayerStr, outputTensorStr).data());
1178 }
1179 else if (descriptor.m_Axis.empty())
1180 {
1181 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1182 reasonIfUnsupported,
1183 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1184 meanLayerStr, outputTensorStr).data());
1185 }
1186 else
1187 {
1188 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1189
1190 if (outputDim > 0)
1191 {
1192 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1193 reasonIfUnsupported,
1194 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1195 meanLayerStr, outputTensorStr).data());
1196 }
1197 else
1198 {
1199 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1200 reasonIfUnsupported,
1201 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1202 meanLayerStr, outputTensorStr).data());
1203 }
1204 }
1205
1206 return supported;
narpra0132b90462018-09-13 11:07:48 +01001207}
1208
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001209bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001210 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001211 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001212 Optional<std::string&> reasonIfUnsupported) const
1213{
Jim Flynne242f2d2019-05-22 14:24:13 +01001214 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001215}
1216
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001217bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1218 const TensorInfo &output,
1219 Optional<std::string &> reasonIfUnsupported) const
1220{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001221 bool supported = true;
1222
1223 std::array<DataType,5> supportedTypes =
1224 {
1225 DataType::Float32,
1226 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001227 DataType::QAsymmU8,
1228 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001229 DataType::Boolean
1230 };
1231
1232 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1233 "Reference MemCopy: input type not supported");
1234
1235 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1236 "Reference MemCopy: output type not supported");
1237
1238 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1239 "Reference MemCopy: input and output types are mismatched");
1240
1241 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001242}
1243
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001244bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1245 const TensorInfo& input1,
1246 const TensorInfo& output,
1247 Optional<std::string&> reasonIfUnsupported) const
1248{
Sadik Armagan2999a022019-04-09 14:20:12 +01001249 bool supported = true;
1250
Matthew Jackson9bff1442019-09-12 09:08:23 +01001251 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001252 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001253 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001254 DataType::QAsymmU8,
1255 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001256 };
1257
1258 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1259 "Reference minimum: input 0 is not a supported type.");
1260
1261 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1262 "Reference minimum: input 1 is not a supported type.");
1263
1264 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1265 "Reference minimum: output is not a supported type.");
1266
1267 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1268 "Reference minimum: input 0 and Input 1 types are mismatched");
1269
1270 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1271 "Reference minimum: input and output types are mismatched");
1272
1273 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1274 "Reference minimum: shapes are not suitable for implicit broadcast.");
1275
1276 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001277}
1278
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001279bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1280 const TensorInfo& input1,
1281 const TensorInfo& output,
1282 Optional<std::string&> reasonIfUnsupported) const
1283{
Sadik Armagan2999a022019-04-09 14:20:12 +01001284 bool supported = true;
1285
Keith Davis67e6c542020-02-19 10:08:33 +00001286 std::array<DataType,6> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001287 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001288 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001289 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001290 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001291 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001292 };
1293
1294 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1295 "Reference multiplication: input 0 is not a supported type.");
1296
1297 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1298 "Reference multiplication: input 1 is not a supported type.");
1299
1300 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1301 "Reference multiplication: output is not a supported type.");
1302
1303 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1304 "Reference multiplication: input 0 and Input 1 types are mismatched");
1305
1306 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1307 "Reference multiplication: input and output types are mismatched");
1308
1309 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1310 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1311
1312 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001313}
1314
1315bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1316 const TensorInfo& output,
1317 const NormalizationDescriptor& descriptor,
1318 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001319{
Nina Drozd661dfa72018-10-02 11:14:17 +01001320 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001321
1322 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001323 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001324 {
1325 DataType::Float16,
1326 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001327 DataType::QAsymmU8,
1328 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001329 };
1330
1331 bool supported = true;
1332
1333 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1334 "Reference normalization: input type not supported.");
1335
1336 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1337 "Reference normalization: output type not supported.");
1338
1339 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1340 "Reference normalization: input and output shapes have different "
1341 "num total elements.");
1342
1343 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001344}
1345
Derek Lamberti901ea112019-12-10 22:07:09 +00001346bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1347 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001348{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001349 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001350}
1351
1352bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1353 const TensorInfo& output,
1354 const PadDescriptor& descriptor,
1355 Optional<std::string&> reasonIfUnsupported) const
1356{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001357 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001358 bool supported = true;
1359
1360 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001361 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001362 {
1363 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001364 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001365 DataType::QAsymmU8,
1366 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001367 };
1368
1369 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1370 "Reference pad: input is not a supported type.");
1371
1372 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1373 "Reference pad: output is not a supported type.");
1374
1375 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1376 "Reference pad: input and output types are mismatched.");
1377
1378 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001379}
1380
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001381bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1382 const TensorInfo& output,
1383 const PermuteDescriptor& descriptor,
1384 Optional<std::string&> reasonIfUnsupported) const
1385{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001386 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001387 bool supported = true;
1388
1389 // Define supported output and inputs types.
1390 std::array<DataType,3> supportedTypes =
1391 {
1392 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001393 DataType::QAsymmU8,
1394 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001395 };
1396
1397 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1398 "Reference permute: input is not a supported type.");
1399
1400 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1401 "Reference permute: output is not a supported type.");
1402
1403 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1404 "Reference permute: input and output types are mismatched.");
1405
1406 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001407}
1408
1409bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1410 const TensorInfo& output,
1411 const Pooling2dDescriptor& descriptor,
1412 Optional<std::string&> reasonIfUnsupported) const
1413{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001414 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001415 bool supported = true;
1416
1417 // Define supported output and inputs types.
Keith Davis67e6c542020-02-19 10:08:33 +00001418 std::array<DataType,5> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001419 {
1420 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001421 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001422 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001423 DataType::QAsymmU8,
1424 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001425 };
1426
1427 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1428 "Reference poolind2d: input is not a supported type.");
1429
1430 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1431 "Reference poolind2d: output is not a supported type.");
1432
1433 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1434 "Reference poolind2d: input and output types are mismatched.");
1435
1436 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001437}
1438
Derek Lamberti5f400d62019-03-25 15:41:58 +00001439bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1440 const TensorInfo& output,
1441 Optional<std::string&> reasonIfUnsupported) const
1442{
1443 bool supported = true;
1444
Finn Williamsfd271062019-12-04 14:27:27 +00001445 // Define supported input types.
Ryan OShea9add1202020-02-07 10:06:33 +00001446 std::array<DataType,6> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00001447 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001448 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001449 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001450 DataType::QAsymmU8,
1451 DataType::QSymmS8,
1452 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001453 };
1454
1455 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1456 "Reference quantize: input type not supported.");
1457
1458 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001459 std::array<DataType,4> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001460 DataType::QAsymmU8,
Ryan OShea9add1202020-02-07 10:06:33 +00001461 DataType::QAsymmS8,
Finn Williamsfd271062019-12-04 14:27:27 +00001462 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001463 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001464 };
1465 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1466 "Reference quantize: output type not supported.");
1467
1468 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1469 "Reference quantize: input and output shapes have different num total elements.");
1470
1471 return supported;
1472}
1473
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001474bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001475 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001476 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001477 Optional<std::string&> reasonIfUnsupported) const
1478{
Kevin Maya023c402019-12-12 17:28:05 +00001479 ignore_unused(output);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001480 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001481 // Define supported output types.
Keith Davis0c2eeac2020-02-11 16:51:50 +00001482 std::array<DataType,7> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001483 {
1484 DataType::Float32,
1485 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001486 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001487 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001488 DataType::QAsymmU8,
1489 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001490 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001491
Nina Drozd2f2778f2019-05-27 10:37:05 +01001492 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1493 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001494}
1495
1496bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001497 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001498 Optional<std::string&> reasonIfUnsupported) const
1499{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001500 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001501 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001502 {
1503 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001504 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001505 DataType::QAsymmU8,
1506 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001507 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001508
1509 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1510 "Reference ResizeBilinear: input type not supported");
1511
1512 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1513 "Reference ResizeBilinear: output type not supported");
1514
1515 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1516 "Reference ResizeBilinear: input and output types not matching");
1517
1518 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001519}
1520
Teresa Charlin970f43b2019-07-01 13:51:07 +01001521bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1522 const TensorInfo& output,
1523 const ResizeDescriptor& descriptor,
1524 Optional<std::string&> reasonIfUnsupported) const
1525{
Derek Lamberti901ea112019-12-10 22:07:09 +00001526 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001527 bool supported = true;
Keith Davis5204aa82020-01-27 15:24:59 +00001528 std::array<DataType,5> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001529 {
1530 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001531 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001532 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001533 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001534 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001535 };
1536
1537 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1538 "Reference Resize: input type not supported");
1539
1540 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1541 "Reference Resize: output type not supported");
1542
1543 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1544 "Reference Resize: input and output types not matching");
1545
1546 return supported;
1547}
1548
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001549bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1550 const TensorInfo& output,
1551 Optional<std::string&> reasonIfUnsupported) const
1552{
josh minor4a3c6102020-01-06 16:40:46 -06001553 return IsElementwiseUnarySupported(input,
1554 output,
1555 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1556 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001557}
1558
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001559bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1560 const TensorInfo& output,
1561 const SliceDescriptor& descriptor,
1562 Optional<std::string&> reasonIfUnsupported) const
1563{
Derek Lamberti901ea112019-12-10 22:07:09 +00001564 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001565 bool supported = true;
1566
1567 std::array<DataType, 3> supportedTypes =
1568 {
1569 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001570 DataType::QAsymmU8,
1571 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001572 };
1573
1574 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1575 "Reference Slice: input type not supported");
1576
1577 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578 "Reference Slice: output type not supported");
1579
1580 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1581 "Reference Slice: input and output types are mismatched");
1582
1583 return supported;
1584}
1585
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001586bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1587 const TensorInfo& output,
1588 const SoftmaxDescriptor& descriptor,
1589 Optional<std::string&> reasonIfUnsupported) const
1590{
Derek Lamberti901ea112019-12-10 22:07:09 +00001591 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001592 bool supported = true;
Keith Davis0c2eeac2020-02-11 16:51:50 +00001593 std::array<DataType,6> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001594 {
1595 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001596 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001597 DataType::QSymmS8,
1598 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001599 DataType::QAsymmU8,
1600 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001601 };
1602
1603 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001604 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001605
1606 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001607 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001608
1609 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001610 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001611
1612 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001613}
1614
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001615bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1616 const TensorInfo& output,
1617 const SpaceToBatchNdDescriptor& descriptor,
1618 Optional<std::string&> reasonIfUnsupported) const
1619{
Derek Lamberti901ea112019-12-10 22:07:09 +00001620 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001621 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001622 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001623 {
1624 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001625 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001626 DataType::QAsymmU8,
1627 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001628 };
1629
1630 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1631 "Reference SpaceToBatchNd: input type not supported");
1632
1633 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1634 "Reference SpaceToBatchNd: output type not supported");
1635
1636 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1637 "Reference SpaceToBatchNd: input and output types are mismatched");
1638
1639 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001640}
1641
Keith Davisa57eccb2019-06-14 17:33:22 +01001642bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001643 const TensorInfo& output,
1644 const SpaceToDepthDescriptor& descriptor,
1645 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001646{
1647
1648 ignore_unused(descriptor);
1649 bool supported = true;
1650
Matthew Jackson9bff1442019-09-12 09:08:23 +01001651 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001652 {
1653 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001654 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001655 DataType::QAsymmU8,
1656 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001657 };
1658
1659 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1660 "Reference SpaceToDepth: input type not supported");
1661
1662 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1663 "Reference SpaceToDepth: output type not supported");
1664
1665 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1666 "Reference SpaceToDepth: input and output types are mismatched");
1667
1668 return supported;
1669}
1670
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001671bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1672 const ViewsDescriptor& descriptor,
1673 Optional<std::string&> reasonIfUnsupported) const
1674{
1675 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001676 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001677 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001678 {
1679 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001680 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001681 DataType::QAsymmU8,
1682 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001683 };
1684
1685 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1686 "Reference splitter: input type not supported");
1687
1688 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001689}
1690
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001691bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1692 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1693 const ViewsDescriptor& descriptor,
1694 Optional<std::string&> reasonIfUnsupported) const
1695{
1696 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001697 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001698 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001699 {
1700 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001701 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001702 DataType::QAsymmU8,
1703 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001704 };
1705
1706 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1707 "Reference splitter: output type not supported");
1708 for (const TensorInfo output : outputs)
1709 {
1710 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1711 "Reference splitter: input type not supported");
1712
1713 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1714 "Reference splitter: input and output types mismatched.");
1715 }
1716
1717 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001718}
1719
Matthew Jackson81e601c2019-07-11 12:07:09 +01001720bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1721 const TensorInfo& output,
1722 const StackDescriptor& descriptor,
1723 Optional<std::string&> reasonIfUnsupported) const
1724{
1725 ignore_unused(descriptor);
1726
1727 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001728 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001729 {
1730 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001731 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001732 DataType::QAsymmU8,
1733 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001734 };
1735
1736 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1737 "Reference stack: output type not supported");
1738 for (const TensorInfo* input : inputs)
1739 {
1740 BOOST_ASSERT(input != nullptr);
1741 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1742 "Reference stack: input type not supported");
1743
1744 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1745 "Reference stack: input and output types mismatched.");
1746 }
1747
1748 return supported;
1749}
1750
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001751bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1752 const TensorInfo& output,
1753 const StridedSliceDescriptor& descriptor,
1754 Optional<std::string&> reasonIfUnsupported) const
1755{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001756 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001757 bool supported = true;
1758
1759 std::array<DataType,3> supportedTypes =
1760 {
1761 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001762 DataType::QAsymmU8,
1763 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001764 };
1765
1766 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1767 "Reference StridedSlice: input type not supported");
1768
1769 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1770 "Reference StridedSlice: output type not supported");
1771
1772 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1773 "Reference StridedSlice: input and output types are mismatched");
1774
1775 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001776}
1777
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001778bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1779 const TensorInfo& input1,
1780 const TensorInfo& output,
1781 Optional<std::string&> reasonIfUnsupported) const
1782{
Sadik Armagan2999a022019-04-09 14:20:12 +01001783 bool supported = true;
1784
Matthew Jackson9bff1442019-09-12 09:08:23 +01001785 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001786 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001787 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001788 DataType::QAsymmU8,
1789 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001790 };
1791
1792 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1793 "Reference subtraction: input 0 is not a supported type.");
1794
1795 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1796 "Reference subtraction: input 1 is not a supported type.");
1797
1798 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1799 "Reference subtraction: output is not a supported type.");
1800
1801 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1802 "Reference subtraction: input 0 and Input 1 types are mismatched");
1803
1804 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1805 "Reference subtraction: input and output types are mismatched");
1806
1807 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1808 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1809
1810 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001811}
1812
Matteo Martincighab9e5252019-06-13 17:27:46 +01001813bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1814 const TensorInfo& alpha,
1815 const TensorInfo& output,
1816 Optional<std::string&> reasonIfUnsupported) const
1817{
1818 bool supported = true;
1819
Matthew Jackson9bff1442019-09-12 09:08:23 +01001820 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001821 {
1822 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001823 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001824 DataType::QAsymmU8,
1825 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001826 };
1827
1828 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1829 "PReLU: input is not a supported type.");
1830
1831 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1832 "PReLU: alpha is not a supported type.");
1833
1834 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1835 "PReLU: output is not a supported type.");
1836
1837 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1838 "PReLU: input, alpha and output types are mismatched");
1839
1840 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1841 "PReLU: shapes are not suitable for implicit broadcast");
1842
1843 return supported;
1844}
1845
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001846bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1847 const TensorInfo& output,
1848 const TransposeConvolution2dDescriptor& descriptor,
1849 const TensorInfo& weights,
1850 const Optional<TensorInfo>& biases,
1851 Optional<std::string&> reasonIfUnsupported) const
1852{
Derek Lamberti901ea112019-12-10 22:07:09 +00001853 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001854 bool supported = true;
1855
Matthew Jackson252df3a2019-09-11 09:19:18 +01001856 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001857 {
1858 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001859 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001860 DataType::QAsymmU8,
1861 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001862 };
1863
1864 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1865 "Reference TransposeConvolution2d: input is not a supported type.");
1866
1867 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1868 "Reference TransposeConvolution2d: output is not a supported type.");
1869
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001870 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1871 "Reference TransposeConvolution2d: input and output types mismatched.");
1872
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001873
1874 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001875 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001876 {
Derek Lambertid466a542020-01-22 15:37:29 +00001877 ARMNN_NO_DEPRECATE_WARN_BEGIN
1878 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001879 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001880 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001881 DataType::QSymmS8,
1882 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001883 };
Derek Lambertid466a542020-01-22 15:37:29 +00001884 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001885
1886 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1887 "Reference TransposeConvolution2d: weights type not supported for "
1888 "quantized input.");
1889 }
1890 else
1891 {
1892 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1893 "Reference TransposeConvolution2d: weights is not a supported type.");
1894
1895 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1896 "Reference TransposeConvolution2d: input and weights types mismatched.");
1897 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001898
1899 if (biases.has_value())
1900 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001901 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001902 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001903 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001904 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001905 DataType::Signed32
1906 };
1907 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1908 "Reference TransposeConvolution2d: biases is not a supported type.");
1909 }
1910
1911 return supported;
1912}
1913
arovir011c7c81b2018-10-08 11:34:28 +01001914} // namespace armnn