//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <armnnUtils/DataLayoutIndexed.hpp>

#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/LayerSupportRules.hpp>

#include <boost/cast.hpp>
#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

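// Forwards to IsSupportedForDataTypeGeneric, wiring the given Float32 and Uint8 handlers
// and rejecting the remaining data types via FalseFunc.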
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

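// Builds the error message used when a tensor does not have the expected number of dimensions.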
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return IsElementwiseUnarySupported(input,
                                       output,
                                       ElementwiseUnaryDescriptor(UnaryOperation::Abs),
                                       reasonIfUnsupported);
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


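    // Local rule: the requested activation function must be one of those handled below.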
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

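    // Quantized inputs are paired with 8-bit weights (asymmetric, symmetric or per-axis symmetric),
    // so the weight type is checked separately from the input type.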
    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QAsymmU8)
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 3> supportedWeightTypes =
        {
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 5> supportedTypes =
    {
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    ARMNN_NO_DEPRECATE_WARN_BEGIN
    std::array<DataType, 3> supportedWeightTypes =
    {
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QuantizedSymm8PerAxis // deprecated
    };
    ARMNN_NO_DEPRECATE_WARN_END

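    // As for Convolution2d, quantized inputs accept the 8-bit weight types listed above;
    // otherwise the weights must match the input type.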
    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QAsymmU8)
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;

}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes = {
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
                                  "Reference dequantize: per-axis quantized input not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  const ElementwiseUnaryDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output shapes "
                                  "have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 3>
            supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

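    // The expected output rank depends on whether reduced dimensions are kept
    // and on how many axes are reduced.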
1158 if (descriptor.m_KeepDims)
1159 {
1160 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1161 reasonIfUnsupported,
1162 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1163 output.GetNumDimensions(),
1164 meanLayerStr, outputTensorStr).data());
1165 }
1166 else if (descriptor.m_Axis.empty())
1167 {
1168 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1169 reasonIfUnsupported,
1170 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1171 meanLayerStr, outputTensorStr).data());
1172 }
1173 else
1174 {
1175 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1176
1177 if (outputDim > 0)
1178 {
1179 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1180 reasonIfUnsupported,
1181 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1182 meanLayerStr, outputTensorStr).data());
1183 }
1184 else
1185 {
1186 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1187 reasonIfUnsupported,
1188 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1189 meanLayerStr, outputTensorStr).data());
1190 }
1191 }
1192
1193 return supported;
narpra0132b90462018-09-13 11:07:48 +01001194}
1195
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001196bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001197 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001198 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001199 Optional<std::string&> reasonIfUnsupported) const
1200{
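// Merger has been superseded by Concat, so defer to the Concat support check.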
Jim Flynne242f2d2019-05-22 14:24:13 +01001201 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001202}
1203
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001204bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1205 const TensorInfo &output,
1206 Optional<std::string &> reasonIfUnsupported) const
1207{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001208 bool supported = true;
1209
1210 std::array<DataType,5> supportedTypes =
1211 {
1212 DataType::Float32,
1213 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001214 DataType::QAsymmU8,
1215 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001216 DataType::Boolean
1217 };
1218
1219 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1220 "Reference MemCopy: input type not supported");
1221
1222 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1223 "Reference MemCopy: output type not supported");
1224
1225 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1226 "Reference MemCopy: input and output types are mismatched");
1227
1228 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001229}
1230
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001231bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1232 const TensorInfo& input1,
1233 const TensorInfo& output,
1234 Optional<std::string&> reasonIfUnsupported) const
1235{
Sadik Armagan2999a022019-04-09 14:20:12 +01001236 bool supported = true;
1237
Matthew Jackson9bff1442019-09-12 09:08:23 +01001238 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001239 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001240 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001241 DataType::QAsymmU8,
1242 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001243 };
1244
1245 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1246 "Reference minimum: input 0 is not a supported type.");
1247
1248 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1249 "Reference minimum: input 1 is not a supported type.");
1250
1251 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1252 "Reference minimum: output is not a supported type.");
1253
1254 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1255 "Reference minimum: input 0 and Input 1 types are mismatched");
1256
1257 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1258 "Reference minimum: input and output types are mismatched");
1259
1260 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1261 "Reference minimum: shapes are not suitable for implicit broadcast.");
1262
1263 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001264}
1265
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001266bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1267 const TensorInfo& input1,
1268 const TensorInfo& output,
1269 Optional<std::string&> reasonIfUnsupported) const
1270{
Sadik Armagan2999a022019-04-09 14:20:12 +01001271 bool supported = true;
1272
Matthew Jackson252df3a2019-09-11 09:19:18 +01001273 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001274 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001275 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001276 DataType::QAsymmU8,
1277 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001278 };
1279
1280 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1281 "Reference multiplication: input 0 is not a supported type.");
1282
1283 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1284 "Reference multiplication: input 1 is not a supported type.");
1285
1286 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1287 "Reference multiplication: output is not a supported type.");
1288
1289 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1290 "Reference multiplication: input 0 and Input 1 types are mismatched");
1291
1292 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1293 "Reference multiplication: input and output types are mismatched");
1294
1295 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1296 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1297
1298 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001299}
1300
1301bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1302 const TensorInfo& output,
1303 const NormalizationDescriptor& descriptor,
1304 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001305{
Nina Drozd661dfa72018-10-02 11:14:17 +01001306 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001307
1308 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001309 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001310 {
1311 DataType::Float16,
1312 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001313 DataType::QAsymmU8,
1314 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001315 };
1316
1317 bool supported = true;
1318
1319 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1320 "Reference normalization: input type not supported.");
1321
1322 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1323 "Reference normalization: output type not supported.");
1324
1325 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1326 "Reference normalization: input and output shapes have different "
1327 "num total elements.");
1328
1329 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001330}
1331
Derek Lamberti901ea112019-12-10 22:07:09 +00001332bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1333 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001334{
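// The reference backend accepts any output tensor, so no checks are required.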
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001335 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001336}
1337
1338bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1339 const TensorInfo& output,
1340 const PadDescriptor& descriptor,
1341 Optional<std::string&> reasonIfUnsupported) const
1342{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001343 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001344 bool supported = true;
1345
1346 // Define supported input and output types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001347 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001348 {
1349 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001350 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001351 DataType::QAsymmU8,
1352 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001353 };
1354
1355 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1356 "Reference pad: input is not a supported type.");
1357
1358 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1359 "Reference pad: output is not a supported type.");
1360
1361 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1362 "Reference pad: input and output types are mismatched.");
1363
1364 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001365}
1366
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001367bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1368 const TensorInfo& output,
1369 const PermuteDescriptor& descriptor,
1370 Optional<std::string&> reasonIfUnsupported) const
1371{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001372 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001373 bool supported = true;
1374
1375 // Define supported input and output types.
1376 std::array<DataType,3> supportedTypes =
1377 {
1378 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001379 DataType::QAsymmU8,
1380 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001381 };
1382
1383 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1384 "Reference permute: input is not a supported type.");
1385
1386 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1387 "Reference permute: output is not a supported type.");
1388
1389 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1390 "Reference permute: input and output types are mismatched.");
1391
1392 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001393}
1394
1395bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1396 const TensorInfo& output,
1397 const Pooling2dDescriptor& descriptor,
1398 Optional<std::string&> reasonIfUnsupported) const
1399{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001400 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001401 bool supported = true;
1402
1403 // Define supported input and output types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001404 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001405 {
1406 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001407 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001408 DataType::QAsymmU8,
1409 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001410 };
1411
1412 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1413 "Reference poolind2d: input is not a supported type.");
1414
1415 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1416 "Reference poolind2d: output is not a supported type.");
1417
1418 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1419 "Reference poolind2d: input and output types are mismatched.");
1420
1421 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001422}
1423
Derek Lamberti5f400d62019-03-25 15:41:58 +00001424bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1425 const TensorInfo& output,
1426 Optional<std::string&> reasonIfUnsupported) const
1427{
1428 bool supported = true;
1429
Finn Williamsfd271062019-12-04 14:27:27 +00001430 // Define supported input types: only Float32 tensors can be quantized by the reference backend.
Sadik Armagan1a816302019-07-29 17:16:40 +01001431 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001432 DataType::Float32,
1433 };
1434
1435 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1436 "Reference quantize: input type not supported.");
1437
1438 // Define supported output types.
Finn Williamsfd271062019-12-04 14:27:27 +00001439 std::array<DataType,3> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001440 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001441 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001442 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001443 };
1444 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1445 "Reference quantize: output type not supported.");
1446
1447 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1448 "Reference quantize: input and output shapes have different num total elements.");
1449
1450 return supported;
1451}
1452
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001453bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001454 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001455 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001456 Optional<std::string&> reasonIfUnsupported) const
1457{
Kevin Maya023c402019-12-12 17:28:05 +00001458 ignore_unused(output);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001459 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001460 // Define supported types; reshape never changes the data type, so only the input is checked.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001461 std::array<DataType,5> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001462 {
1463 DataType::Float32,
1464 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001465 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001466 DataType::QAsymmU8,
1467 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001468 };
1469 return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1470 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001471}
1472
1473bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001474 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001475 Optional<std::string&> reasonIfUnsupported) const
1476{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001477 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001478 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001479 {
1480 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001481 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001482 DataType::QAsymmU8,
1483 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001484 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001485
1486 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1487 "Reference ResizeBilinear: input type not supported");
1488
1489 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1490 "Reference ResizeBilinear: output type not supported");
1491
1492 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1493 "Reference ResizeBilinear: input and output types not matching");
1494
1495 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001496}
1497
Teresa Charlin970f43b2019-07-01 13:51:07 +01001498bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1499 const TensorInfo& output,
1500 const ResizeDescriptor& descriptor,
1501 Optional<std::string&> reasonIfUnsupported) const
1502{
Derek Lamberti901ea112019-12-10 22:07:09 +00001503 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001504 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001505 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001506 {
1507 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001508 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001509 DataType::QAsymmU8,
1510 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001511 };
1512
1513 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1514 "Reference Resize: input type not supported");
1515
1516 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1517 "Reference Resize: output type not supported");
1518
1519 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1520 "Reference Resize: input and output types not matching");
1521
1522 return supported;
1523}
1524
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001525bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1526 const TensorInfo& output,
1527 Optional<std::string&> reasonIfUnsupported) const
1528{
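// Rsqrt is handled by the generic elementwise-unary support check.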
josh minor4a3c6102020-01-06 16:40:46 -06001529 return IsElementwiseUnarySupported(input,
1530 output,
1531 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1532 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001533}
1534
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001535bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1536 const TensorInfo& output,
1537 const SliceDescriptor& descriptor,
1538 Optional<std::string&> reasonIfUnsupported) const
1539{
Derek Lamberti901ea112019-12-10 22:07:09 +00001540 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001541 bool supported = true;
1542
1543 std::array<DataType, 3> supportedTypes =
1544 {
1545 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001546 DataType::QAsymmU8,
1547 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001548 };
1549
1550 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1551 "Reference Slice: input type not supported");
1552
1553 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1554 "Reference Slice: output type not supported");
1555
1556 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1557 "Reference Slice: input and output types are mismatched");
1558
1559 return supported;
1560}
1561
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001562bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1563 const TensorInfo& output,
1564 const SoftmaxDescriptor& descriptor,
1565 Optional<std::string&> reasonIfUnsupported) const
1566{
Derek Lamberti901ea112019-12-10 22:07:09 +00001567 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001568 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001569 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001570 {
1571 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001572 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001573 DataType::QAsymmU8,
1574 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001575 };
1576
1577 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001578 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001579
1580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001581 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001582
1583 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001584 "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001585
1586 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001587}
1588
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001589bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1590 const TensorInfo& output,
1591 const SpaceToBatchNdDescriptor& descriptor,
1592 Optional<std::string&> reasonIfUnsupported) const
1593{
Derek Lamberti901ea112019-12-10 22:07:09 +00001594 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001595 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001596 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001597 {
1598 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001599 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001600 DataType::QAsymmU8,
1601 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001602 };
1603
1604 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1605 "Reference SpaceToBatchNd: input type not supported");
1606
1607 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1608 "Reference SpaceToBatchNd: output type not supported");
1609
1610 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1611 "Reference SpaceToBatchNd: input and output types are mismatched");
1612
1613 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001614}
1615
Keith Davisa57eccb2019-06-14 17:33:22 +01001616bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001617 const TensorInfo& output,
1618 const SpaceToDepthDescriptor& descriptor,
1619 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001620{
1621
1622 ignore_unused(descriptor);
1623 bool supported = true;
1624
Matthew Jackson9bff1442019-09-12 09:08:23 +01001625 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001626 {
1627 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001628 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001629 DataType::QAsymmU8,
1630 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001631 };
1632
1633 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1634 "Reference SpaceToDepth: input type not supported");
1635
1636 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1637 "Reference SpaceToDepth: output type not supported");
1638
1639 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1640 "Reference SpaceToDepth: input and output types are mismatched");
1641
1642 return supported;
1643}
1644
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001645bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1646 const ViewsDescriptor& descriptor,
1647 Optional<std::string&> reasonIfUnsupported) const
1648{
1649 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001650 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001651 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001652 {
1653 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001654 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001655 DataType::QAsymmU8,
1656 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001657 };
1658
1659 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1660 "Reference splitter: input type not supported");
1661
1662 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001663}
1664
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001665bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1666 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1667 const ViewsDescriptor& descriptor,
1668 Optional<std::string&> reasonIfUnsupported) const
1669{
1670 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001671 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001672 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001673 {
1674 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001675 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001676 DataType::QAsymmU8,
1677 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001678 };
1679
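// The input must be a supported type, and every output view must be a supported type
// that matches the input.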
1680 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1681 "Reference splitter: input type not supported");
1682 for (const TensorInfo& output : outputs)
1683 {
1684 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1685 "Reference splitter: output type not supported");
1686
1687 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1688 "Reference splitter: input and output types mismatched.");
1689 }
1690
1691 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001692}
1693
Matthew Jackson81e601c2019-07-11 12:07:09 +01001694bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1695 const TensorInfo& output,
1696 const StackDescriptor& descriptor,
1697 Optional<std::string&> reasonIfUnsupported) const
1698{
1699 ignore_unused(descriptor);
1700
1701 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001702 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001703 {
1704 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001705 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001706 DataType::QAsymmU8,
1707 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001708 };
1709
1710 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1711 "Reference stack: output type not supported");
1712 for (const TensorInfo* input : inputs)
1713 {
1714 BOOST_ASSERT(input != nullptr);
1715 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1716 "Reference stack: input type not supported");
1717
1718 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1719 "Reference stack: input and output types mismatched.");
1720 }
1721
1722 return supported;
1723}
1724
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001725bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1726 const TensorInfo& output,
1727 const StridedSliceDescriptor& descriptor,
1728 Optional<std::string&> reasonIfUnsupported) const
1729{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001730 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001731 bool supported = true;
1732
1733 std::array<DataType,3> supportedTypes =
1734 {
1735 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001736 DataType::QAsymmU8,
1737 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001738 };
1739
1740 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1741 "Reference StridedSlice: input type not supported");
1742
1743 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1744 "Reference StridedSlice: output type not supported");
1745
1746 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1747 "Reference StridedSlice: input and output types are mismatched");
1748
1749 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001750}
1751
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001752bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1753 const TensorInfo& input1,
1754 const TensorInfo& output,
1755 Optional<std::string&> reasonIfUnsupported) const
1756{
Sadik Armagan2999a022019-04-09 14:20:12 +01001757 bool supported = true;
1758
Matthew Jackson9bff1442019-09-12 09:08:23 +01001759 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001760 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001761 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001762 DataType::QAsymmU8,
1763 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001764 };
1765
1766 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1767 "Reference subtraction: input 0 is not a supported type.");
1768
1769 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1770 "Reference subtraction: input 1 is not a supported type.");
1771
1772 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1773 "Reference subtraction: output is not a supported type.");
1774
1775 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1776 "Reference subtraction: input 0 and Input 1 types are mismatched");
1777
1778 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1779 "Reference subtraction: input and output types are mismatched");
1780
1781 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1782 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1783
1784 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001785}
1786
Matteo Martincighab9e5252019-06-13 17:27:46 +01001787bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1788 const TensorInfo& alpha,
1789 const TensorInfo& output,
1790 Optional<std::string&> reasonIfUnsupported) const
1791{
1792 bool supported = true;
1793
Matthew Jackson9bff1442019-09-12 09:08:23 +01001794 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001795 {
1796 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001797 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001798 DataType::QAsymmU8,
1799 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001800 };
1801
1802 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1803 "PReLU: input is not a supported type.");
1804
1805 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1806 "PReLU: alpha is not a supported type.");
1807
1808 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1809 "PReLU: output is not a supported type.");
1810
1811 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1812 "PReLU: input, alpha and output types are mismatched");
1813
1814 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1815 "PReLU: shapes are not suitable for implicit broadcast");
1816
1817 return supported;
1818}
1819
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001820bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1821 const TensorInfo& output,
1822 const TransposeConvolution2dDescriptor& descriptor,
1823 const TensorInfo& weights,
1824 const Optional<TensorInfo>& biases,
1825 Optional<std::string&> reasonIfUnsupported) const
1826{
Derek Lamberti901ea112019-12-10 22:07:09 +00001827 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001828 bool supported = true;
1829
Matthew Jackson252df3a2019-09-11 09:19:18 +01001830 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001831 {
1832 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001833 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001834 DataType::QAsymmU8,
1835 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001836 };
1837
1838 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1839 "Reference TransposeConvolution2d: input is not a supported type.");
1840
1841 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1842 "Reference TransposeConvolution2d: output is not a supported type.");
1843
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001844 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1845 "Reference TransposeConvolution2d: input and output types mismatched.");
1846
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001847
1848 const DataType inputType = input.GetDataType();
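// For QAsymmU8 inputs the weights may be QAsymmU8, QSymmS8 or the deprecated per-axis
// QuantizedSymm8PerAxis type; for all other input types the weights must match the input type.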
Derek Lambertif90c56d2020-01-10 17:14:08 +00001849 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001850 {
Derek Lambertid466a542020-01-22 15:37:29 +00001851 ARMNN_NO_DEPRECATE_WARN_BEGIN
1852 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001853 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001854 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00001855 DataType::QSymmS8,
1856 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001857 };
Derek Lambertid466a542020-01-22 15:37:29 +00001858 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001859
1860 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1861 "Reference TransposeConvolution2d: weights type not supported for "
1862 "quantized input.");
1863 }
1864 else
1865 {
1866 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1867 "Reference TransposeConvolution2d: weights is not a supported type.");
1868
1869 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1870 "Reference TransposeConvolution2d: input and weights types mismatched.");
1871 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001872
1873 if (biases.has_value())
1874 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001875 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001876 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001877 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001878 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001879 DataType::Signed32
1880 };
1881 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1882 "Reference TransposeConvolution2d: biases is not a supported type.");
1883 }
1884
1885 return supported;
1886}
1887
arovir011c7c81b2018-10-08 11:34:28 +01001888} // namespace armnn