//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <armnnUtils/DataLayoutIndexed.hpp>

#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/LayerSupportRules.hpp>

#include <boost/cast.hpp>
#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

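// Dispatches on the tensor data type: only the Float32 and Uint8 callbacks are supplied by the
// caller; every other data type falls through to FalseFunc and is reported as unsupported.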
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

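// Builds the "expected N dimensions but got M" reason string used by the rank checks below.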
namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return IsElementwiseUnarySupported(input,
                                       output,
                                       ElementwiseUnaryDescriptor(UnaryOperation::Abs),
                                       reasonIfUnsupported);
}

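// Activation: input and output must share one of the supported types and the same rank, and the
// activation function must be one handled by the reference workload (see the rule below).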
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

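// Comparison: both inputs must share a supported type; the output must always be Boolean.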
bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

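// Concat: every input must match the output's type; the concatenation axis is not validated here.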
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

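// The conversion layers accept only the exact Float16 -> Float32 (and Float32 -> Float16) pairing.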
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

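// Convolution2d: input and output must share a supported type. Quantized (QAsymmU8) inputs may use
// QAsymmU8 or per-axis symmetric 8-bit weights; an optional bias must be Float32, Float16 or Signed32.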
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QAsymmU8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QAsymmU8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 5> supportedTypes =
    {
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

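// DepthwiseConvolution2d applies the same type rules as Convolution2d, including the per-axis
// quantized weight option for QAsymmU8 inputs.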
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QAsymmU8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QAsymmU8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

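// Dequantize: a quantized input (QAsymmU8, QSymmS8 or QSymmS16) maps to a Float32 or Float16
// output with the same total number of elements.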
bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes = {
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  const ElementwiseUnaryDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference elementwise unary: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output shapes "
                                  "have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

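// FullyConnected: input, output and weights are checked against the same type list; when the bias
// is enabled it must be Float32, Float16 or Signed32 and compatible with the weight type.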
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

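// Lstm supports Float32 and QSymmS16 only. Every state tensor and every enabled weight/bias
// parameter (CIFG, peephole, projection, layer normalization) must match the input type.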
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

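// Mean: besides the type checks, the output rank must be consistent with keepDims and the number
// of reduction axes.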
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001125bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1126 const TensorInfo& output,
1127 const MeanDescriptor& descriptor,
1128 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001129{
James Conroy4d1ff582019-06-10 17:06:39 +01001130 bool supported = true;
1131 std::string meanLayerStr = "Mean";
1132 std::string outputTensorStr = "output";
1133
Matthew Jackson252df3a2019-09-11 09:19:18 +01001134 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001135 {
1136 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001137 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001138 DataType::QAsymmU8,
1139 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001140 };
1141
1142 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1143 "Reference Mean: input type not supported.");
1144
1145 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1146 "Reference Mean: input and output types are mismatched");
1147
1148 if (descriptor.m_KeepDims)
1149 {
1150 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1151 reasonIfUnsupported,
1152 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1153 output.GetNumDimensions(),
1154 meanLayerStr, outputTensorStr).data());
1155 }
1156 else if (descriptor.m_Axis.empty())
1157 {
1158 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1159 reasonIfUnsupported,
1160 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1161 meanLayerStr, outputTensorStr).data());
1162 }
1163 else
1164 {
1165 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1166
1167 if (outputDim > 0)
1168 {
1169 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1170 reasonIfUnsupported,
1171 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1172 meanLayerStr, outputTensorStr).data());
1173 }
1174 else
1175 {
1176 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1177 reasonIfUnsupported,
1178 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1179 meanLayerStr, outputTensorStr).data());
1180 }
1181 }
1182
1183 return supported;
narpra0132b90462018-09-13 11:07:48 +01001184}
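// Worked example for the Mean dimension rules above (shapes are illustrative assumptions):
// for a [1, 3, 4, 4] input with descriptor.m_Axis = { 2, 3 },
//   - m_KeepDims == true  expects a rank-4 output such as [1, 3, 1, 1];
//   - m_KeepDims == false expects a rank-2 output (4 input dimensions minus 2 reduced axes);
//   - an empty m_Axis (with m_KeepDims == false) reduces over every dimension, so a rank-1
//     output is expected.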
1185
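// Merger is the legacy name for Concat in this API; the overload below simply forwards to
// IsConcatSupported so that both entry points apply the same rules.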
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001186bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001187 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001188 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001189 Optional<std::string&> reasonIfUnsupported) const
1190{
Jim Flynne242f2d2019-05-22 14:24:13 +01001191 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001192}
1193
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001194bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1195 const TensorInfo &output,
1196 Optional<std::string &> reasonIfUnsupported) const
1197{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001198 bool supported = true;
1199
1200 std::array<DataType,5> supportedTypes =
1201 {
1202 DataType::Float32,
1203 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001204 DataType::QAsymmU8,
1205 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001206 DataType::Boolean
1207 };
1208
1209 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1210 "Reference MemCopy: input type not supported");
1211
1212 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1213 "Reference MemCopy: output type not supported");
1214
1215 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1216 "Reference MemCopy: input and output types are mismatched");
1217
1218 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001219}
1220
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001221bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1222 const TensorInfo& input1,
1223 const TensorInfo& output,
1224 Optional<std::string&> reasonIfUnsupported) const
1225{
Sadik Armagan2999a022019-04-09 14:20:12 +01001226 bool supported = true;
1227
Matthew Jackson9bff1442019-09-12 09:08:23 +01001228 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001229 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001230 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001231 DataType::QAsymmU8,
1232 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001233 };
1234
1235 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1236 "Reference minimum: input 0 is not a supported type.");
1237
1238 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1239 "Reference minimum: input 1 is not a supported type.");
1240
1241 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1242 "Reference minimum: output is not a supported type.");
1243
1244 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1245                                  "Reference minimum: input 0 and input 1 types are mismatched");
1246
1247 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1248 "Reference minimum: input and output types are mismatched");
1249
1250 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1251 "Reference minimum: shapes are not suitable for implicit broadcast.");
1252
1253 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001254}
1255
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001256bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1257 const TensorInfo& input1,
1258 const TensorInfo& output,
1259 Optional<std::string&> reasonIfUnsupported) const
1260{
Sadik Armagan2999a022019-04-09 14:20:12 +01001261 bool supported = true;
1262
Matthew Jackson252df3a2019-09-11 09:19:18 +01001263 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001264 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001265 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001266 DataType::QAsymmU8,
1267 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001268 };
1269
1270 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1271 "Reference multiplication: input 0 is not a supported type.");
1272
1273 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1274 "Reference multiplication: input 1 is not a supported type.");
1275
1276 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1277 "Reference multiplication: output is not a supported type.");
1278
1279 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1280                                  "Reference multiplication: input 0 and input 1 types are mismatched");
1281
1282 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1283 "Reference multiplication: input and output types are mismatched");
1284
1285 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1286 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1287
1288 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001289}
1290
1291bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1292 const TensorInfo& output,
1293 const NormalizationDescriptor& descriptor,
1294 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001295{
Nina Drozd661dfa72018-10-02 11:14:17 +01001296 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001297
1298 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001299 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001300 {
1301 DataType::Float16,
1302 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001303 DataType::QAsymmU8,
1304 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001305 };
1306
1307 bool supported = true;
1308
1309 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1310 "Reference normalization: input type not supported.");
1311
1312 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1313 "Reference normalization: output type not supported.");
1314
1315 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1316 "Reference normalization: input and output shapes have different "
1317 "num total elements.");
1318
1319 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001320}
1321
Derek Lamberti901ea112019-12-10 22:07:09 +00001322bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1323 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001324{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001325 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001326}
1327
1328bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1329 const TensorInfo& output,
1330 const PadDescriptor& descriptor,
1331 Optional<std::string&> reasonIfUnsupported) const
1332{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001333 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001334 bool supported = true;
1335
1336 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001337 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001338 {
1339 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001340 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001341 DataType::QAsymmU8,
1342 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001343 };
1344
1345 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1346 "Reference pad: input is not a supported type.");
1347
1348 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1349 "Reference pad: output is not a supported type.");
1350
1351 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1352 "Reference pad: input and output types are mismatched.");
1353
1354 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001355}
1356
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001357bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1358 const TensorInfo& output,
1359 const PermuteDescriptor& descriptor,
1360 Optional<std::string&> reasonIfUnsupported) const
1361{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001362 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001363 bool supported = true;
1364
1365 // Define supported output and inputs types.
1366 std::array<DataType,3> supportedTypes =
1367 {
1368 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001369 DataType::QAsymmU8,
1370 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001371 };
1372
1373 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1374 "Reference permute: input is not a supported type.");
1375
1376 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1377 "Reference permute: output is not a supported type.");
1378
1379 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1380 "Reference permute: input and output types are mismatched.");
1381
1382 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001383}
1384
1385bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1386 const TensorInfo& output,
1387 const Pooling2dDescriptor& descriptor,
1388 Optional<std::string&> reasonIfUnsupported) const
1389{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001390 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001391 bool supported = true;
1392
1393 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001394 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001395 {
1396 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001397 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001398 DataType::QAsymmU8,
1399 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001400 };
1401
1402 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1403                                  "Reference pooling2d: input is not a supported type.");
1404
1405 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1406                                  "Reference pooling2d: output is not a supported type.");
1407
1408 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1409                                  "Reference pooling2d: input and output types are mismatched.");
1410
1411 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001412}
1413
Derek Lamberti5f400d62019-03-25 15:41:58 +00001414bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1415 const TensorInfo& output,
1416 Optional<std::string&> reasonIfUnsupported) const
1417{
1418 bool supported = true;
1419
Finn Williamsfd271062019-12-04 14:27:27 +00001420 // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001421 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001422 DataType::Float32,
1423 };
1424
1425 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1426 "Reference quantize: input type not supported.");
1427
1428 // Define supported output types.
Finn Williamsfd271062019-12-04 14:27:27 +00001429 std::array<DataType,3> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001430 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001431 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001432 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001433 };
1434 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1435 "Reference quantize: output type not supported.");
1436
1437 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1438 "Reference quantize: input and output shapes have different num total elements.");
1439
1440 return supported;
1441}
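// Illustrative sketch of a Quantize support query (shapes, scale and offset are assumptions,
// not taken from this file): a Float32 source is quantized to one of the 8/16-bit types listed
// above, and only the total element count has to match:
//
//     armnn::TensorInfo src({ 1, 2, 2, 1 }, armnn::DataType::Float32);
//     armnn::TensorInfo dst({ 1, 2, 2, 1 }, armnn::DataType::QAsymmU8, 0.1f, 128);
//     std::string reason;
//     bool ok = armnn::RefLayerSupport().IsQuantizeSupported(src, dst,
//                                                            armnn::Optional<std::string&>(reason));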
1442
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001443bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001444 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001445 Optional<std::string&> reasonIfUnsupported) const
1446{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001447 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001448    // Define supported types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001449    std::array<DataType,5> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001450    {
1451        DataType::Float32,
1452        DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001453        DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001454        DataType::QAsymmU8,
1455        DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001456    };
1457    return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1458 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001459}
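// Note: the reshape check above only validates the element type of the input; the target shape in
// the descriptor and the output tensor are not inspected here.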
1460
1461bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001462 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001463 Optional<std::string&> reasonIfUnsupported) const
1464{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001465 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001466 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001467 {
1468 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001469 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001470 DataType::QAsymmU8,
1471 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001472 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001473
1474 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1475 "Reference ResizeBilinear: input type not supported");
1476
1477 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1478 "Reference ResizeBilinear: output type not supported");
1479
1480 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1481                                  "Reference ResizeBilinear: input and output types are mismatched");
1482
1483 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001484}
1485
Teresa Charlin970f43b2019-07-01 13:51:07 +01001486bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1487 const TensorInfo& output,
1488 const ResizeDescriptor& descriptor,
1489 Optional<std::string&> reasonIfUnsupported) const
1490{
Derek Lamberti901ea112019-12-10 22:07:09 +00001491 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001492 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001493 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001494 {
1495 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001496 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001497 DataType::QAsymmU8,
1498 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001499 };
1500
1501 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1502 "Reference Resize: input type not supported");
1503
1504 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1505 "Reference Resize: output type not supported");
1506
1507 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1508                                  "Reference Resize: input and output types are mismatched");
1509
1510 return supported;
1511}
1512
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001513bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1514 const TensorInfo& output,
1515 Optional<std::string&> reasonIfUnsupported) const
1516{
josh minor4a3c6102020-01-06 16:40:46 -06001517 return IsElementwiseUnarySupported(input,
1518 output,
1519 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1520 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001521}
1522
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001523bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1524 const TensorInfo& output,
1525 const SliceDescriptor& descriptor,
1526 Optional<std::string&> reasonIfUnsupported) const
1527{
Derek Lamberti901ea112019-12-10 22:07:09 +00001528 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001529 bool supported = true;
1530
1531 std::array<DataType, 3> supportedTypes =
1532 {
1533 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001534 DataType::QAsymmU8,
1535 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001536 };
1537
1538 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1539 "Reference Slice: input type not supported");
1540
1541 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1542 "Reference Slice: output type not supported");
1543
1544 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1545 "Reference Slice: input and output types are mismatched");
1546
1547 return supported;
1548}
1549
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001550bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1551 const TensorInfo& output,
1552 const SoftmaxDescriptor& descriptor,
1553 Optional<std::string&> reasonIfUnsupported) const
1554{
Derek Lamberti901ea112019-12-10 22:07:09 +00001555 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001556 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001557 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001558 {
1559 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001560 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001561 DataType::QAsymmU8,
1562 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001563 };
1564
1565 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001566                                  "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001567
1568 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001569                                  "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001570
1571 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001572                                  "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001573
1574 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001575}
1576
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001577bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1578 const TensorInfo& output,
1579 const SpaceToBatchNdDescriptor& descriptor,
1580 Optional<std::string&> reasonIfUnsupported) const
1581{
Derek Lamberti901ea112019-12-10 22:07:09 +00001582 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001583 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001584 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001585 {
1586 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001587 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001588 DataType::QAsymmU8,
1589 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001590 };
1591
1592 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1593 "Reference SpaceToBatchNd: input type not supported");
1594
1595 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1596 "Reference SpaceToBatchNd: output type not supported");
1597
1598 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1599 "Reference SpaceToBatchNd: input and output types are mismatched");
1600
1601 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001602}
1603
Keith Davisa57eccb2019-06-14 17:33:22 +01001604bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001605 const TensorInfo& output,
1606 const SpaceToDepthDescriptor& descriptor,
1607 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001608{
1609
1610 ignore_unused(descriptor);
1611 bool supported = true;
1612
Matthew Jackson9bff1442019-09-12 09:08:23 +01001613 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001614 {
1615 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001616 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001617 DataType::QAsymmU8,
1618 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001619 };
1620
1621 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1622 "Reference SpaceToDepth: input type not supported");
1623
1624 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1625 "Reference SpaceToDepth: output type not supported");
1626
1627 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1628 "Reference SpaceToDepth: input and output types are mismatched");
1629
1630 return supported;
1631}
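// Shape intuition for SpaceToDepth (illustrative, assuming NHWC layout and a block size of 2):
// a [1, 4, 4, 1] input becomes a [1, 2, 2, 4] output, with each 2x2 spatial block moved into the
// channel dimension; the data type is unchanged, so the rules above only need to agree on the
// element type.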
1632
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001633bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1634 const ViewsDescriptor& descriptor,
1635 Optional<std::string&> reasonIfUnsupported) const
1636{
1637 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001638 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001639 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001640 {
1641 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001642 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001643 DataType::QAsymmU8,
1644 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001645 };
1646
1647 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1648 "Reference splitter: input type not supported");
1649
1650 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001651}
1652
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001653bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1654 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1655 const ViewsDescriptor& descriptor,
1656 Optional<std::string&> reasonIfUnsupported) const
1657{
1658 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001659 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001660 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001661 {
1662 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001663 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001664 DataType::QAsymmU8,
1665 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001666 };
1667
1668 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1669                                  "Reference splitter: input type not supported");
1670    for (const TensorInfo& output : outputs)
1671    {
1672        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1673                                      "Reference splitter: output type not supported");
1674
1675 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1676 "Reference splitter: input and output types mismatched.");
1677 }
1678
1679 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001680}
1681
Matthew Jackson81e601c2019-07-11 12:07:09 +01001682bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1683 const TensorInfo& output,
1684 const StackDescriptor& descriptor,
1685 Optional<std::string&> reasonIfUnsupported) const
1686{
1687 ignore_unused(descriptor);
1688
1689 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001690 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001691 {
1692 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001693 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001694 DataType::QAsymmU8,
1695 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001696 };
1697
1698 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1699 "Reference stack: output type not supported");
1700 for (const TensorInfo* input : inputs)
1701 {
1702 BOOST_ASSERT(input != nullptr);
1703 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1704 "Reference stack: input type not supported");
1705
1706 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1707 "Reference stack: input and output types mismatched.");
1708 }
1709
1710 return supported;
1711}
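// Shape intuition for Stack (illustrative): stacking descriptor.m_NumInputs tensors of identical
// shape inserts a new dimension at descriptor.m_Axis, e.g. four [3, 4] inputs stacked on axis 0
// produce a [4, 3, 4] output. The rules above only constrain the element types; shape consistency
// is left to the layer and workload validation.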
1712
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001713bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1714 const TensorInfo& output,
1715 const StridedSliceDescriptor& descriptor,
1716 Optional<std::string&> reasonIfUnsupported) const
1717{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001718 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001719 bool supported = true;
1720
1721 std::array<DataType,3> supportedTypes =
1722 {
1723 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001724 DataType::QAsymmU8,
1725 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001726 };
1727
1728 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1729 "Reference StridedSlice: input type not supported");
1730
1731 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1732 "Reference StridedSlice: output type not supported");
1733
1734 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1735 "Reference StridedSlice: input and output types are mismatched");
1736
1737 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001738}
1739
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001740bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1741 const TensorInfo& input1,
1742 const TensorInfo& output,
1743 Optional<std::string&> reasonIfUnsupported) const
1744{
Sadik Armagan2999a022019-04-09 14:20:12 +01001745 bool supported = true;
1746
Matthew Jackson9bff1442019-09-12 09:08:23 +01001747 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001748 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001749 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001750 DataType::QAsymmU8,
1751 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001752 };
1753
1754 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1755 "Reference subtraction: input 0 is not a supported type.");
1756
1757 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1758 "Reference subtraction: input 1 is not a supported type.");
1759
1760 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1761 "Reference subtraction: output is not a supported type.");
1762
1763 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1764                                  "Reference subtraction: input 0 and input 1 types are mismatched");
1765
1766 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1767 "Reference subtraction: input and output types are mismatched");
1768
1769 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1770 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1771
1772 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001773}
1774
Matteo Martincighab9e5252019-06-13 17:27:46 +01001775bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1776 const TensorInfo& alpha,
1777 const TensorInfo& output,
1778 Optional<std::string&> reasonIfUnsupported) const
1779{
1780 bool supported = true;
1781
Matthew Jackson9bff1442019-09-12 09:08:23 +01001782 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001783 {
1784 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001785 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001786 DataType::QAsymmU8,
1787 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001788 };
1789
1790 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1791 "PReLU: input is not a supported type.");
1792
1793 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1794 "PReLU: alpha is not a supported type.");
1795
1796 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1797 "PReLU: output is not a supported type.");
1798
1799 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1800 "PReLU: input, alpha and output types are mismatched");
1801
1802 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1803 "PReLU: shapes are not suitable for implicit broadcast");
1804
1805 return supported;
1806}
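// Typical PReLU shapes (illustrative): a [1, 4, 4, 3] input with a per-channel alpha of shape
// [1, 1, 1, 3] (or simply [3]) passes ShapesAreBroadcastCompatible, giving one learned slope per
// channel while the output keeps the input shape [1, 4, 4, 3].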
1807
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001808bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1809 const TensorInfo& output,
1810 const TransposeConvolution2dDescriptor& descriptor,
1811 const TensorInfo& weights,
1812 const Optional<TensorInfo>& biases,
1813 Optional<std::string&> reasonIfUnsupported) const
1814{
Derek Lamberti901ea112019-12-10 22:07:09 +00001815 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001816 bool supported = true;
1817
Matthew Jackson252df3a2019-09-11 09:19:18 +01001818 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001819 {
1820 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001821 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001822 DataType::QAsymmU8,
1823 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001824 };
1825
1826 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1827 "Reference TransposeConvolution2d: input is not a supported type.");
1828
1829 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1830 "Reference TransposeConvolution2d: output is not a supported type.");
1831
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001832 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1833 "Reference TransposeConvolution2d: input and output types mismatched.");
1834
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001835
1836 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001837 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001838 {
1839 std::array<DataType, 2> supportedWeightTypes =
1840 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001841 DataType::QAsymmU8,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001842 DataType::QuantizedSymm8PerAxis
1843 };
1844
1845 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1846 "Reference TransposeConvolution2d: weights type not supported for "
1847 "quantized input.");
1848 }
1849 else
1850 {
1851 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1852 "Reference TransposeConvolution2d: weights is not a supported type.");
1853
1854 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1855 "Reference TransposeConvolution2d: input and weights types mismatched.");
1856 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001857
1858 if (biases.has_value())
1859 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001860 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001861 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001862 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001863 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001864 DataType::Signed32
1865 };
1866 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1867 "Reference TransposeConvolution2d: biases is not a supported type.");
1868 }
1869
1870 return supported;
1871}
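// Note on the quantized path above: with a QAsymmU8 input the weights may instead be
// QuantizedSymm8PerAxis, i.e. symmetric signed 8-bit weights carrying one scale per output channel
// (configured through TensorInfo's per-axis quantization setters; SetQuantizationScales and
// SetQuantizationDim are assumed here). Non-quantized inputs keep the usual rule that input and
// weights share the same data type.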
1872
arovir011c7c81b2018-10-08 11:34:28 +01001873} // namespace armnn