blob: 26a61d45d571377c5b115c9142ac36d09bdf19b7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000011#include <armnn/BackendRegistry.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <armnnUtils/DataLayoutIndexed.hpp>
14
15#include <InternalTypes.hpp>
16#include <LayerSupportCommon.hpp>
17
Derek Lambertif674aa02019-08-01 15:56:25 +010018#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000019
Matteo Martincighe011d202019-11-28 11:35:47 +000020#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000022
Derek Lamberti50db4e82019-03-13 14:16:15 +000023#include <vector>
24#include <algorithm>
25#include <array>
26
telsoa014fcda012018-03-09 14:13:49 +000027using namespace boost;
28
29namespace armnn
30{
31
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010032namespace
33{
34
35template<typename Float32Func, typename Uint8Func, typename ... Params>
36bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
37 DataType dataType,
38 Float32Func floatFuncPtr,
39 Uint8Func uint8FuncPtr,
40 Params&&... params)
41{
42 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
43 dataType,
44 &FalseFunc<Params...>,
45 floatFuncPtr,
46 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000047 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000048 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010049 std::forward<Params>(params)...);
50}
51
52} // anonymous namespace
53
namespace
{

// Builds the human-readable message reported when a tensor has the wrong rank, e.g.
//   "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions instead, for the 'input' tensor."
//
// expected   - number of dimensions the layer requires.
// actual     - number of dimensions the tensor actually has.
// layerStr   - layer name inserted after "Reference ".
// tensorName - tensor name inserted between single quotes.
//
// The string arguments are only read, so they are taken by const reference. This lets callers
// pass const strings and temporaries, while existing callers with mutable lvalues still bind.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got " +
           std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000069
Sadik Armagan9199e582019-09-05 17:35:31 +010070bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
71 Optional<std::string&> reasonIfUnsupported) const
72{
73 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +010074 std::array<DataType,4> supportedTypes =
Sadik Armagan9199e582019-09-05 17:35:31 +010075 {
76 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +010077 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +000078 DataType::QAsymmU8,
79 DataType::QSymmS16
Sadik Armagan9199e582019-09-05 17:35:31 +010080 };
81
82 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
83 "Reference abs: input type not supported");
84
85 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
86 "Reference abs: output type not supported");
87
88 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
89 "Reference abs: input and output types not matching");
90
91 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
92 "Reference abs: input and output shapes have different number of total elements");
93
94 return supported;
95}
96
arovir011c7c81b2018-10-08 11:34:28 +010097bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
98 const TensorInfo& output,
99 const ActivationDescriptor& descriptor,
100 Optional<std::string&> reasonIfUnsupported) const
101{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000102 bool supported = true;
103
104 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100105 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000106 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100107 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000108 DataType::QAsymmU8,
109 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000110 };
111
112 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
113 "Reference activation: input type not supported.");
114
115 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
116 "Reference activation: output type not supported.");
117
118 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
119 "Reference activation: input and output types mismatched.");
120
121 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
122 "Reference activation: input and output shapes are of different rank.");
123
124
125 struct ActivationFunctionSupported : public Rule
126 {
127 ActivationFunctionSupported(const ActivationDescriptor& desc)
128 {
129 switch(desc.m_Function)
130 {
131 case ActivationFunction::Abs:
132 case ActivationFunction::BoundedReLu:
133 case ActivationFunction::LeakyReLu:
134 case ActivationFunction::Linear:
135 case ActivationFunction::ReLu:
136 case ActivationFunction::Sigmoid:
137 case ActivationFunction::SoftReLu:
138 case ActivationFunction::Sqrt:
139 case ActivationFunction::Square:
140 case ActivationFunction::TanH:
141 {
142 m_Res = true;
143 break;
144 }
145 default:
146 {
147 m_Res = false;
148 break;
149 }
150 }
151 }
152 };
153
154 // Function is supported
155 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
156 "Reference activation: function not supported.");
157
158 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100159}
160
161bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
162 const TensorInfo& input1,
163 const TensorInfo& output,
164 Optional<std::string&> reasonIfUnsupported) const
165{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000166 bool supported = true;
167
Matthew Jackson252df3a2019-09-11 09:19:18 +0100168 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000169 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100170 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000171 DataType::QAsymmU8,
172 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000173 };
174
175 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
176 "Reference addition: input 0 is not a supported type.");
177
178 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
179 "Reference addition: input 1 is not a supported type.");
180
181 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
182 "Reference addition: output is not a supported type.");
183
184 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
185 "Reference addition: input 0 and Input 1 types are mismatched");
186
187 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
188 "Reference addition: input and output types are mismatched");
189
190 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
191 "Reference addition: shapes are not suitable for implicit broadcast.");
192
193 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100194}
195
Nikhil Raj68c2c902019-09-19 11:21:11 +0100196bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
197 const armnn::ArgMinMaxDescriptor &descriptor,
198 armnn::Optional<std::string &> reasonIfUnsupported) const
199{
200 ignore_unused(descriptor);
201
Francis Murtagh1939df52019-11-13 15:21:09 +0000202 std::array<DataType, 4> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100203 {
204 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000205 DataType::QAsymmU8,
206 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000207 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100208 };
209
210 bool supported = true;
211
212 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
213 "Reference ArgMinMax: input is not a supported type.");
214 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
215 "Reference ArgMinMax: output type not supported");
216
217 return supported;
218}
219
arovir011c7c81b2018-10-08 11:34:28 +0100220bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
221 const TensorInfo& output,
222 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100223 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100224 const TensorInfo& beta,
225 const TensorInfo& gamma,
226 const BatchNormalizationDescriptor& descriptor,
227 Optional<std::string&> reasonIfUnsupported) const
228{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100229 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100230
Matthew Jackson9bff1442019-09-12 09:08:23 +0100231 std::array<DataType, 4> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100232 {
233 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100234 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000235 DataType::QAsymmU8,
236 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100237 };
238
239 bool supported = true;
240
241 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
242 "Reference batch normalization: input is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
245 "Reference batch normalization: output is not a supported type.");
246
247 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
248 "Reference batch normalization: input and output types are mismatched");
249
250 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
251 "Reference batch normalization: mean is not a supported type.");
252
253 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
254 "Reference batch normalization: variance is not a supported type.");
255
256 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
257 "Reference batch normalization: beta is not a supported type.");
258
259 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
260 "Reference batch normalization: gamma is not a supported type.");
261
262 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100263}
264
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000265bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
266 const TensorInfo& output,
267 const BatchToSpaceNdDescriptor& descriptor,
268 Optional<std::string&> reasonIfUnsupported) const
269{
270 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100271
272 bool supported = true;
273
274 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
275 std::string inputTensorStr = "input";
276 std::string outputTensorStr = "output";
277
278 // Define supported types.
Matthew Jackson9bff1442019-09-12 09:08:23 +0100279 std::array<DataType,4> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100280 {
281 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100282 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000283 DataType::QAsymmU8,
284 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100285 };
286
287 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
288 "Reference BatchToSpaceNd: input type not supported.");
289
290 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
291 "Reference BatchToSpaceNd: output type not supported.");
292
293 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
294 "Reference BatchToSpaceNd: input and output types mismatched.");
295
296 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
297 reasonIfUnsupported,
298 CreateIncorrectDimensionsErrorMsg(4,
299 output.GetNumDimensions(),
300 batchToSpaceNdLayerStr,
301 outputTensorStr).data());
302
303 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
304 reasonIfUnsupported,
305 CreateIncorrectDimensionsErrorMsg(4,
306 input.GetNumDimensions(),
307 batchToSpaceNdLayerStr,
308 inputTensorStr).data());
309
310 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000311}
312
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100313bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
314 const TensorInfo& input1,
315 const TensorInfo& output,
316 const ComparisonDescriptor& descriptor,
317 Optional<std::string&> reasonIfUnsupported) const
318{
319 boost::ignore_unused(descriptor);
320
321 std::array<DataType, 4> supportedInputTypes =
322 {
323 DataType::Float32,
324 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000325 DataType::QAsymmU8,
326 DataType::QSymmS16
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100327 };
328
329 bool supported = true;
330 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
331 "Reference comparison: input 0 is not a supported type");
332
333 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
334 "Reference comparison: input 0 and Input 1 types are mismatched");
335
336 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
337 "Reference comparison: output is not of type Boolean");
338
339 return supported;
340}
341
Jim Flynn906f9462019-05-10 13:55:21 +0100342bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
343 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100344 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100345 Optional<std::string&> reasonIfUnsupported) const
346{
Jim Flynne242f2d2019-05-22 14:24:13 +0100347 ignore_unused(descriptor);
348
349 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100350 std::array<DataType,4> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100351 {
352 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100353 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000354 DataType::QAsymmU8,
355 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100356 };
357
358 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
359 "Reference concatenation: output type not supported");
360 for (const TensorInfo* input : inputs)
361 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100362 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100363 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
364 "Reference concatenation: input type not supported");
365
366 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
367 "Reference concatenation: input and output types mismatched.");
368 }
369
370 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100371}
372
arovir011c7c81b2018-10-08 11:34:28 +0100373bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
374 Optional<std::string&> reasonIfUnsupported) const
375{
Jim Flynne242f2d2019-05-22 14:24:13 +0100376 std::array<DataType,4> supportedTypes =
377 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100378 DataType::Float32,
379 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000380 DataType::QAsymmU8,
381 DataType::QSymmS16
Nina Drozd58ef2c62019-05-16 12:09:18 +0100382 };
383
384 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
385 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100386}
387
388bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
389 const TensorInfo& output,
390 Optional<std::string&> reasonIfUnsupported) const
391{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100392 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
393 input.GetDataType(),
394 &TrueFunc<>,
395 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000396 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000397 &FalseFuncI32<>,
398 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100399 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
400 output.GetDataType(),
401 &FalseOutputFuncF16<>,
402 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000403 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000404 &FalseFuncI32<>,
405 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100406}
407
408bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
409 const TensorInfo& output,
410 Optional<std::string&> reasonIfUnsupported) const
411{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100412 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
413 input.GetDataType(),
414 &FalseInputFuncF16<>,
415 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000416 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000417 &FalseFuncI32<>,
418 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100419 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
420 output.GetDataType(),
421 &TrueFunc<>,
422 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000423 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000424 &FalseFuncI32<>,
425 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100426}
427
428bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
429 const TensorInfo& output,
430 const Convolution2dDescriptor& descriptor,
431 const TensorInfo& weights,
432 const Optional<TensorInfo>& biases,
433 Optional<std::string&> reasonIfUnsupported) const
434{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100435 bool supported = true;
436
437 // Define supported types.
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000438 std::array<DataType,4> supportedTypes =
439 {
440 DataType::Float32,
441 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000442 DataType::QAsymmU8,
443 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100444 };
445
446 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100447 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100448
449 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100450 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100451
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100452 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100453 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100454
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000455 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +0000456 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000457 {
458 std::array<DataType, 2> supportedWeightTypes =
459 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000460 DataType::QAsymmU8,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000461 DataType::QuantizedSymm8PerAxis
462 };
463
464 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
465 "Reference convolution2d: weights type not supported for quantized input.");
466 }
467 else
468 {
469 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
470 "Reference convolution2d: weights is not a supported type.");
471
472 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
473 "Reference convolution2d: input and weights types mismatched.");
474 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100475
476 if (biases.has_value())
477 {
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000478 std::array<DataType,3> biasesSupportedTypes =
479 {
480 DataType::Float32,
481 DataType::Float16,
482 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100483 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000484
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100485 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100486 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100487 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100488 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100489
490 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100491}
492
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000493bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
494 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000495 Optional<std::string&> reasonIfUnsupported) const
496{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100497 bool supported = true;
498
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000499 std::array<DataType, 5> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100500 {
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000501 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100502 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000503 DataType::QAsymmU8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000504 DataType::QSymmS16,
505 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100506 };
507
508 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
509 "Reference debug: input type not supported");
510
511 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
512 "Reference debug: output type not supported");
513
514 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
515 "Reference debug: input and output types are mismatched");
516
517 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000518}
519
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100520bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
521 const TensorInfo& output,
522 const DepthToSpaceDescriptor& descriptor,
523 Optional<std::string&> reasonIfUnsupported) const
524{
525 ignore_unused(descriptor);
526 bool supported = true;
527
528 std::array<DataType,4> supportedTypes =
529 {
530 DataType::Float32,
531 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000532 DataType::QAsymmU8,
533 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100534 };
535
536 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
537 "Reference DepthToSpace: input type not supported");
538
539 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
540 "Reference DepthToSpace: output type not supported");
541
542 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
543 "Reference DepthToSpace: input and output types are mismatched");
544
545 return supported;
546}
547
arovir011c7c81b2018-10-08 11:34:28 +0100548bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
549 const TensorInfo& output,
550 const DepthwiseConvolution2dDescriptor& descriptor,
551 const TensorInfo& weights,
552 const Optional<TensorInfo>& biases,
553 Optional<std::string&> reasonIfUnsupported) const
554{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100555 bool supported = true;
556
557 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100558 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100559 {
560 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100561 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000562 DataType::QAsymmU8,
563 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100564 };
565
566 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
567 "Reference DepthwiseConvolution2d: input is not a supported type.");
568
569 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
570 "Reference DepthwiseConvolution2d: output is not a supported type.");
571
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100572 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
573 "Reference DepthwiseConvolution2d: input and output types mismatched.");
574
Teresa Charlind8df0262019-11-11 12:28:15 +0000575 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +0000576 if (inputType == DataType::QAsymmU8)
Teresa Charlind8df0262019-11-11 12:28:15 +0000577 {
578 std::array<DataType, 2> supportedWeightTypes =
579 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000580 DataType::QAsymmU8,
Teresa Charlind8df0262019-11-11 12:28:15 +0000581 DataType::QuantizedSymm8PerAxis
582 };
583
584 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
585 "Reference convolution2d: weights type not supported for quantized input.");
586 }
587 else
588 {
589 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
590 "Reference DepthwiseConvolution2d: weights is not a supported type.");
591
592 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
593 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
594 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100595
596 if (biases.has_value())
597 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100598 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100599 {
600 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100601 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100602 DataType::Signed32
603 };
604 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
605 "Reference DepthwiseConvolution2d: biases is not a supported type.");
606 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100607 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100608
609 return supported;
610
arovir011c7c81b2018-10-08 11:34:28 +0100611}
612
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000613bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
614 const TensorInfo& output,
615 Optional<std::string&> reasonIfUnsupported) const
616{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100617 bool supported = true;
618
Finn Williamsfd271062019-12-04 14:27:27 +0000619 std::array<DataType,3> supportedInputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000620 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000621 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000622 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100623 };
624
625 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
626 "Reference dequantize: input type not supported.");
627
Jan Eilersf7107932019-11-01 11:09:36 +0000628 std::array<DataType,2> supportedOutputTypes = {
629 DataType::Float32,
630 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100631 };
632
633 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
634 "Reference dequantize: output type not supported.");
635
636 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
637 "Reference dequantize: input and output shapes have different num total elements.");
638
639 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000640}
641
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000642bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
643 const TensorInfo& scores,
644 const TensorInfo& anchors,
645 const TensorInfo& detectionBoxes,
646 const TensorInfo& detectionClasses,
647 const TensorInfo& detectionScores,
648 const TensorInfo& numDetections,
649 const DetectionPostProcessDescriptor& descriptor,
650 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000651{
Derek Lamberti901ea112019-12-10 22:07:09 +0000652 boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
653
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100654 bool supported = true;
655
Mike Kelly4992c342019-08-14 11:33:11 +0100656 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100657 {
658 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000659 DataType::QAsymmU8,
660 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100661 };
662
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000663 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100664 "Reference DetectionPostProcess: input 0 is not a supported type.");
665
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000666 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100667 "Reference DetectionPostProcess: input 1 is not a supported type.");
668
669 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000670}
671
Pablo Tellof0bd6832019-04-26 17:58:13 +0100672bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
673 const TensorInfo& output,
674 const DepthwiseConvolution2dDescriptor& descriptor,
675 const TensorInfo& weights,
676 const Optional<TensorInfo>& biases,
677 Optional<std::string&> reasonIfUnsupported) const
678{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100679 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100680}
681
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100682bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100683 const TensorInfo& input1,
684 const TensorInfo& output,
685 Optional<std::string&> reasonIfUnsupported) const
686{
Sadik Armagan2999a022019-04-09 14:20:12 +0100687 bool supported = true;
688
Matthew Jackson9bff1442019-09-12 09:08:23 +0100689 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100690 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100691 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000692 DataType::QAsymmU8,
693 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +0100694 };
695
696 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
697 "Reference division: input 0 is not a supported type.");
698
699 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
700 "Reference division: input 1 is not a supported type.");
701
702 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
703 "Reference division: output is not a supported type.");
704
705 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
706 "Reference division: input 0 and Input 1 types are mismatched");
707
708 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
709 "Reference division: input and output types are mismatched");
710
711 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
712 "Reference division: shapes are not suitable for implicit broadcast.");
713
714 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100715}
716
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000717bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
718 const TensorInfo& input1,
719 const TensorInfo& output,
720 Optional<std::string&> reasonIfUnsupported) const
721{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100722 return IsComparisonSupported(input0,
723 input1,
724 output,
725 ComparisonDescriptor(ComparisonOperation::Equal),
726 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000727}
728
arovir011c7c81b2018-10-08 11:34:28 +0100729bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
730 const FakeQuantizationDescriptor& descriptor,
731 Optional<std::string&> reasonIfUnsupported) const
732{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100733 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100734 bool supported = true;
735
736 std::array<DataType,1> supportedTypes =
737 {
738 DataType::Float32
739 };
740
741 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742 "Reference fake quantization: input type not supported.");
743
744 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100745}
746
747bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
748 const TensorInfo& output,
749 Optional<std::string&> reasonIfUnsupported) const
750{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100751 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100752 bool supported = true;
753
Matthew Jackson9bff1442019-09-12 09:08:23 +0100754 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100755 {
James Conroyb40d7102019-06-04 12:32:09 +0100756 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100757 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000758 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +0100759 };
760
761 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
762 "Reference Floor: input type not supported.");
763
764 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
765 "Reference Floor: output type not supported.");
766
767 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100768}
769
770bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
771 const TensorInfo& output,
772 const TensorInfo& weights,
773 const TensorInfo& biases,
774 const FullyConnectedDescriptor& descriptor,
775 Optional<std::string&> reasonIfUnsupported) const
776{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100777 bool supported = true;
778
779 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100780 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100781 {
782 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100783 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000784 DataType::QAsymmU8,
785 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100786 };
787
788 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789 "Reference Fully Connected: input type not supported.");
790
791 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792 "Reference Fully Connected: output type not supported.");
793
794 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795 "Reference Fully Connected: input and output types mismatched.");
796
797 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
798 "Reference Fully Connected: weights type not supported.");
799
800 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
801 "Reference Fully Connected: input and weight types mismatched.");
802
803 if (descriptor.m_BiasEnabled)
804 {
805 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100806 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100807 supportedBiasTypes =
808 {
809 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100810 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100811 DataType::Signed32
812 };
813
814 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
815 "Reference Fully Connected: bias type not supported.");
816
817 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
818 "Reference Fully Connected: bias and weight types mismatch.");
819
820 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
821 "Reference Fully Connected: bias type inferred from weights is incompatible.");
822
823 }
824
825 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100826}
827
narpra014951d842019-01-18 16:53:53 +0000828bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
829 const armnn::TensorInfo& input1,
830 const armnn::TensorInfo& output,
831 armnn::Optional<std::string&> reasonIfUnsupported) const
832{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100833 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100834 std::array<DataType,4> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100835 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100836 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100837 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000838 DataType::QAsymmU8,
839 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100840 };
841
842 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
843 "Reference Gather: input type not supported");
844
845 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
846 "Reference Gather: output type not supported");
847
848 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
849 "Reference Gather: indices (input1) type not supported");
850
851 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
852 "Reference Gather: input and output types not matching");
853
854 return supported;
narpra014951d842019-01-18 16:53:53 +0000855}
856
FrancisMurtagh878f0232018-12-19 10:56:15 +0000857bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
858 const TensorInfo& input1,
859 const TensorInfo& output,
860 Optional<std::string&> reasonIfUnsupported) const
861{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100862 return IsComparisonSupported(input0,
863 input1,
864 output,
865 ComparisonDescriptor(ComparisonOperation::Greater),
866 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +0000867}
868
Derek Lamberti901ea112019-12-10 22:07:09 +0000869bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
870 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +0100871{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100872 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100873}
874
Kevin May09ca49c2019-10-09 12:37:34 +0100875bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
876 const TensorInfo& output,
877 const InstanceNormalizationDescriptor& descriptor,
878 Optional<std::string&> reasonIfUnsupported) const
879{
880 ignore_unused(descriptor);
881 // Define supported types
882 std::array<DataType, 4> supportedTypes =
883 {
884 DataType::Float32,
885 DataType::Float16
886 };
887
888 bool supported = true;
889
890 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
891 "Reference Instance Normalization: input type not supported.");
892
893 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
894 "Reference Instance Normalization: output type not supported.");
895
896 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
897 "Reference Instance Normalization: input and output types mismatched.");
898
899 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
900 "Reference Instance Normalization: input and output shapes have different "
901 "num total elements.");
902
903 return supported;
904}
905
arovir011c7c81b2018-10-08 11:34:28 +0100906bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
907 const TensorInfo& output,
908 const L2NormalizationDescriptor& descriptor,
909 Optional<std::string&> reasonIfUnsupported) const
910{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100911 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100912 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100913 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100914 {
915 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100916 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000917 DataType::QAsymmU8,
918 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100919 };
920
921 bool supported = true;
922
923 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
924 "Reference L2normalization: input type not supported.");
925
926 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
927 "Reference L2normalization: output type not supported.");
928
929 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
930 "Reference L2normalization: input and output types mismatched.");
931
932 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
933 "Reference L2normalization: input and output shapes have different "
934 "num total elements.");
935
936 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100937}
938
Aron Virginas-Tare662a942019-10-14 15:12:00 +0100939bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
940 const TensorInfo& output,
941 const LogSoftmaxDescriptor& descriptor,
942 Optional<std::string&> reasonIfUnsupported) const
943{
944 ignore_unused(descriptor);
945
946 std::array<DataType, 2> supportedTypes =
947 {
948 DataType::Float32,
949 DataType::Float16
950 };
951
952 bool supported = true;
953 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
954 "Reference LogSoftmax: input type not supported");
955
956 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
957 "Reference LogSoftmax: output type not supported");
958
959 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
960 "Reference LogSoftmax: input and output types do not match");
961
962 return supported;
963}
964
arovir011c7c81b2018-10-08 11:34:28 +0100965bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
966 const TensorInfo& outputStateIn,
967 const TensorInfo& cellStateIn,
968 const TensorInfo& scratchBuffer,
969 const TensorInfo& outputStateOut,
970 const TensorInfo& cellStateOut,
971 const TensorInfo& output,
972 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100973 const LstmInputParamsInfo& paramsInfo,
974 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +0100975{
telsoa01c577f2c2018-08-31 09:22:23 +0100976 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100977 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100978
979 bool supported = true;
980
981 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100982 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000983 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100984 };
985
Jan Eilersd01a83c2019-07-03 18:20:40 +0100986 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100987 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
988 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100989 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
990 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100991 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
992 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100993 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
994 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100995 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
996 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100997 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
998 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100999 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1000 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001001 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001002 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001003 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001004 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001005 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001006 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001007 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001008 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001009 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001010 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001011 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001012 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001013 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001014 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001015 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001016 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001017 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001018 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001019 "Reference Lstm: input and OutputGateBias types are mismatched");
1020 if (!descriptor.m_CifgEnabled)
1021 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001022 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001023 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001024 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001025 reasonIfUnsupported,
1026 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001027 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001028 "Reference Lstm: input and InputGateBias types are mismatched");
1029 if (descriptor.m_PeepholeEnabled)
1030 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001031 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001032 reasonIfUnsupported,
1033 "Reference Lstm: input and CellToInputWeights types are mismatched");
1034 }
1035 }
1036 if (descriptor.m_PeepholeEnabled)
1037 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001038 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001039 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001040 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001041 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1042 }
1043 if (descriptor.m_ProjectionEnabled)
1044 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001045 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001046 "Reference Lstm: input and mProjectionWeights types are mismatched");
1047 if (paramsInfo.m_ProjectionBias != nullptr)
1048 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001049 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001050 "Reference Lstm: input and ProjectionBias types are mismatched");
1051 }
1052 }
1053 if (descriptor.m_LayerNormEnabled)
1054 {
1055 if (!descriptor.m_CifgEnabled)
1056 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001057 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001058 reasonIfUnsupported,
1059 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1060 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001061 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001062 reasonIfUnsupported,
1063 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001064 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001065 reasonIfUnsupported,
1066 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001067 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001068 reasonIfUnsupported,
1069 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1070 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001071
1072 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001073}
1074
saoste012df12b32018-11-28 16:57:20 +00001075bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1076 const TensorInfo& input1,
1077 const TensorInfo& output,
1078 Optional<std::string&> reasonIfUnsupported) const
1079{
Sadik Armagan2999a022019-04-09 14:20:12 +01001080 bool supported = true;
1081
Matthew Jackson9bff1442019-09-12 09:08:23 +01001082 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001083 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001084 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001085 DataType::QAsymmU8,
1086 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001087 };
1088
1089 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1090 "Reference maximum: input 0 is not a supported type.");
1091
1092 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1093 "Reference maximum: input 1 is not a supported type.");
1094
1095 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1096 "Reference maximum: output is not a supported type.");
1097
1098 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1099 "Reference maximum: input 0 and Input 1 types are mismatched");
1100
1101 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1102 "Reference maximum: input and output types are mismatched");
1103
1104 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1105 "Reference maximum: shapes are not suitable for implicit broadcast.");
1106
1107 return supported;
saoste012df12b32018-11-28 16:57:20 +00001108}
1109
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001110bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1111 const TensorInfo& output,
1112 const MeanDescriptor& descriptor,
1113 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001114{
James Conroy4d1ff582019-06-10 17:06:39 +01001115 bool supported = true;
1116 std::string meanLayerStr = "Mean";
1117 std::string outputTensorStr = "output";
1118
Matthew Jackson252df3a2019-09-11 09:19:18 +01001119 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001120 {
1121 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001122 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001123 DataType::QAsymmU8,
1124 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001125 };
1126
1127 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1128 "Reference Mean: input type not supported.");
1129
1130 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1131 "Reference Mean: input and output types are mismatched");
1132
1133 if (descriptor.m_KeepDims)
1134 {
1135 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1136 reasonIfUnsupported,
1137 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1138 output.GetNumDimensions(),
1139 meanLayerStr, outputTensorStr).data());
1140 }
1141 else if (descriptor.m_Axis.empty())
1142 {
1143 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1144 reasonIfUnsupported,
1145 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1146 meanLayerStr, outputTensorStr).data());
1147 }
1148 else
1149 {
1150 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1151
1152 if (outputDim > 0)
1153 {
1154 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1155 reasonIfUnsupported,
1156 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1157 meanLayerStr, outputTensorStr).data());
1158 }
1159 else
1160 {
1161 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1162 reasonIfUnsupported,
1163 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1164 meanLayerStr, outputTensorStr).data());
1165 }
1166 }
1167
1168 return supported;
narpra0132b90462018-09-13 11:07:48 +01001169}
1170
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001171bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001172 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001173 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001174 Optional<std::string&> reasonIfUnsupported) const
1175{
Jim Flynne242f2d2019-05-22 14:24:13 +01001176 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001177}
1178
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001179bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1180 const TensorInfo &output,
1181 Optional<std::string &> reasonIfUnsupported) const
1182{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001183 bool supported = true;
1184
1185 std::array<DataType,5> supportedTypes =
1186 {
1187 DataType::Float32,
1188 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001189 DataType::QAsymmU8,
1190 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001191 DataType::Boolean
1192 };
1193
1194 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1195 "Reference MemCopy: input type not supported");
1196
1197 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1198 "Reference MemCopy: output type not supported");
1199
1200 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1201 "Reference MemCopy: input and output types are mismatched");
1202
1203 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001204}
1205
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001206bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1207 const TensorInfo& input1,
1208 const TensorInfo& output,
1209 Optional<std::string&> reasonIfUnsupported) const
1210{
Sadik Armagan2999a022019-04-09 14:20:12 +01001211 bool supported = true;
1212
Matthew Jackson9bff1442019-09-12 09:08:23 +01001213 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001214 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001215 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001216 DataType::QAsymmU8,
1217 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001218 };
1219
1220 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1221 "Reference minimum: input 0 is not a supported type.");
1222
1223 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1224 "Reference minimum: input 1 is not a supported type.");
1225
1226 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1227 "Reference minimum: output is not a supported type.");
1228
1229 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1230 "Reference minimum: input 0 and Input 1 types are mismatched");
1231
1232 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1233 "Reference minimum: input and output types are mismatched");
1234
1235 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1236 "Reference minimum: shapes are not suitable for implicit broadcast.");
1237
1238 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001239}
1240
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001241bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1242 const TensorInfo& input1,
1243 const TensorInfo& output,
1244 Optional<std::string&> reasonIfUnsupported) const
1245{
Sadik Armagan2999a022019-04-09 14:20:12 +01001246 bool supported = true;
1247
Matthew Jackson252df3a2019-09-11 09:19:18 +01001248 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001249 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001250 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001251 DataType::QAsymmU8,
1252 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001253 };
1254
1255 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1256 "Reference multiplication: input 0 is not a supported type.");
1257
1258 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1259 "Reference multiplication: input 1 is not a supported type.");
1260
1261 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1262 "Reference multiplication: output is not a supported type.");
1263
1264 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1265 "Reference multiplication: input 0 and Input 1 types are mismatched");
1266
1267 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1268 "Reference multiplication: input and output types are mismatched");
1269
1270 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1271 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1272
1273 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001274}
1275
1276bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1277 const TensorInfo& output,
1278 const NormalizationDescriptor& descriptor,
1279 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001280{
Nina Drozd661dfa72018-10-02 11:14:17 +01001281 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001282
1283 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001284 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001285 {
1286 DataType::Float16,
1287 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001288 DataType::QAsymmU8,
1289 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001290 };
1291
1292 bool supported = true;
1293
1294 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1295 "Reference normalization: input type not supported.");
1296
1297 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1298 "Reference normalization: output type not supported.");
1299
1300 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1301 "Reference normalization: input and output shapes have different "
1302 "num total elements.");
1303
1304 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001305}
1306
Derek Lamberti901ea112019-12-10 22:07:09 +00001307bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1308 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001309{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001310 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001311}
1312
1313bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1314 const TensorInfo& output,
1315 const PadDescriptor& descriptor,
1316 Optional<std::string&> reasonIfUnsupported) const
1317{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001318 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001319 bool supported = true;
1320
1321 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001322 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001323 {
1324 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001325 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001326 DataType::QAsymmU8,
1327 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001328 };
1329
1330 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1331 "Reference pad: input is not a supported type.");
1332
1333 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1334 "Reference pad: output is not a supported type.");
1335
1336 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1337 "Reference pad: input and output types are mismatched.");
1338
1339 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001340}
1341
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001342bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1343 const TensorInfo& output,
1344 const PermuteDescriptor& descriptor,
1345 Optional<std::string&> reasonIfUnsupported) const
1346{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001347 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001348 bool supported = true;
1349
1350 // Define supported output and inputs types.
1351 std::array<DataType,3> supportedTypes =
1352 {
1353 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001354 DataType::QAsymmU8,
1355 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001356 };
1357
1358 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1359 "Reference permute: input is not a supported type.");
1360
1361 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1362 "Reference permute: output is not a supported type.");
1363
1364 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1365 "Reference permute: input and output types are mismatched.");
1366
1367 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001368}
1369
1370bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1371 const TensorInfo& output,
1372 const Pooling2dDescriptor& descriptor,
1373 Optional<std::string&> reasonIfUnsupported) const
1374{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001375 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001376 bool supported = true;
1377
1378 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001379 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001380 {
1381 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001382 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001383 DataType::QAsymmU8,
1384 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001385 };
1386
1387 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1388 "Reference poolind2d: input is not a supported type.");
1389
1390 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1391 "Reference poolind2d: output is not a supported type.");
1392
1393 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1394 "Reference poolind2d: input and output types are mismatched.");
1395
1396 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001397}
1398
Derek Lamberti5f400d62019-03-25 15:41:58 +00001399bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1400 const TensorInfo& output,
1401 Optional<std::string&> reasonIfUnsupported) const
1402{
1403 bool supported = true;
1404
Finn Williamsfd271062019-12-04 14:27:27 +00001405 // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001406 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001407 DataType::Float32,
1408 };
1409
1410 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1411 "Reference quantize: input type not supported.");
1412
1413 // Define supported output types.
Finn Williamsfd271062019-12-04 14:27:27 +00001414 std::array<DataType,3> supportedOutputTypes = {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001415 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001416 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001417 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001418 };
1419 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1420 "Reference quantize: output type not supported.");
1421
1422 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1423 "Reference quantize: input and output shapes have different num total elements.");
1424
1425 return supported;
1426}
1427
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001428bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001429 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001430 Optional<std::string&> reasonIfUnsupported) const
1431{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001432 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001433 // Define supported output types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001434 std::array<DataType,5> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001435 {
1436 DataType::Float32,
1437 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001438 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001439 DataType::QAsymmU8,
1440 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001441 };
1442 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1443 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001444}
1445
1446bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001447 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001448 Optional<std::string&> reasonIfUnsupported) const
1449{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001450 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001451 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001452 {
1453 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001454 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001455 DataType::QAsymmU8,
1456 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001457 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001458
1459 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1460 "Reference ResizeBilinear: input type not supported");
1461
1462 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1463 "Reference ResizeBilinear: output type not supported");
1464
1465 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1466 "Reference ResizeBilinear: input and output types not matching");
1467
1468 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001469}
1470
Teresa Charlin970f43b2019-07-01 13:51:07 +01001471bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1472 const TensorInfo& output,
1473 const ResizeDescriptor& descriptor,
1474 Optional<std::string&> reasonIfUnsupported) const
1475{
Derek Lamberti901ea112019-12-10 22:07:09 +00001476 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001477 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001478 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001479 {
1480 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001481 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001482 DataType::QAsymmU8,
1483 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001484 };
1485
1486 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1487 "Reference Resize: input type not supported");
1488
1489 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1490 "Reference Resize: output type not supported");
1491
1492 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1493 "Reference Resize: input and output types not matching");
1494
1495 return supported;
1496}
1497
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001498bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1499 const TensorInfo& output,
1500 Optional<std::string&> reasonIfUnsupported) const
1501{
nikraj010421e7f2019-06-14 09:40:34 +01001502 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001503 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001504 {
1505 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001506 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001507 DataType::QAsymmU8,
1508 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01001509 };
1510
1511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512 "Reference rsqrt: input type not supported");
1513
1514 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515 "Reference rsqrt: output type not supported");
1516
1517 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1518 "Reference rsqrt: input and output types not matching");
1519
1520 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1521 "Reference Rsqrt: input and output shapes have different number of total elements");
1522
1523 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001524}
1525
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001526bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1527 const TensorInfo& output,
1528 const SliceDescriptor& descriptor,
1529 Optional<std::string&> reasonIfUnsupported) const
1530{
Derek Lamberti901ea112019-12-10 22:07:09 +00001531 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001532 bool supported = true;
1533
1534 std::array<DataType, 3> supportedTypes =
1535 {
1536 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001537 DataType::QAsymmU8,
1538 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001539 };
1540
1541 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1542 "Reference Slice: input type not supported");
1543
1544 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1545 "Reference Slice: output type not supported");
1546
1547 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1548 "Reference Slice: input and output types are mismatched");
1549
1550 return supported;
1551}
1552
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001553bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1554 const TensorInfo& output,
1555 const SoftmaxDescriptor& descriptor,
1556 Optional<std::string&> reasonIfUnsupported) const
1557{
Derek Lamberti901ea112019-12-10 22:07:09 +00001558 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001559 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001560 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001561 {
1562 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001563 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001564 DataType::QAsymmU8,
1565 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001566 };
1567
1568 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001569 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001570
1571 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001572 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001573
1574 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001575 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001576
1577 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001578}
1579
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001580bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1581 const TensorInfo& output,
1582 const SpaceToBatchNdDescriptor& descriptor,
1583 Optional<std::string&> reasonIfUnsupported) const
1584{
Derek Lamberti901ea112019-12-10 22:07:09 +00001585 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001586 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001587 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001588 {
1589 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001590 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001591 DataType::QAsymmU8,
1592 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001593 };
1594
1595 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1596 "Reference SpaceToBatchNd: input type not supported");
1597
1598 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1599 "Reference SpaceToBatchNd: output type not supported");
1600
1601 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1602 "Reference SpaceToBatchNd: input and output types are mismatched");
1603
1604 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001605}
1606
Keith Davisa57eccb2019-06-14 17:33:22 +01001607bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001608 const TensorInfo& output,
1609 const SpaceToDepthDescriptor& descriptor,
1610 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001611{
1612
1613 ignore_unused(descriptor);
1614 bool supported = true;
1615
Matthew Jackson9bff1442019-09-12 09:08:23 +01001616 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001617 {
1618 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001619 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001620 DataType::QAsymmU8,
1621 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001622 };
1623
1624 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1625 "Reference SpaceToDepth: input type not supported");
1626
1627 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1628 "Reference SpaceToDepth: output type not supported");
1629
1630 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1631 "Reference SpaceToDepth: input and output types are mismatched");
1632
1633 return supported;
1634}
1635
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001636bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1637 const ViewsDescriptor& descriptor,
1638 Optional<std::string&> reasonIfUnsupported) const
1639{
1640 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001641 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001642 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001643 {
1644 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001645 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001646 DataType::QAsymmU8,
1647 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001648 };
1649
1650 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1651 "Reference splitter: input type not supported");
1652
1653 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001654}
1655
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001656bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1657 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1658 const ViewsDescriptor& descriptor,
1659 Optional<std::string&> reasonIfUnsupported) const
1660{
1661 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001662 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001663 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001664 {
1665 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001666 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001667 DataType::QAsymmU8,
1668 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001669 };
1670
1671 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1672 "Reference splitter: output type not supported");
1673 for (const TensorInfo output : outputs)
1674 {
1675 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1676 "Reference splitter: input type not supported");
1677
1678 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1679 "Reference splitter: input and output types mismatched.");
1680 }
1681
1682 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001683}
1684
Matthew Jackson81e601c2019-07-11 12:07:09 +01001685bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1686 const TensorInfo& output,
1687 const StackDescriptor& descriptor,
1688 Optional<std::string&> reasonIfUnsupported) const
1689{
1690 ignore_unused(descriptor);
1691
1692 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001693 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001694 {
1695 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001696 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001697 DataType::QAsymmU8,
1698 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01001699 };
1700
1701 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1702 "Reference stack: output type not supported");
1703 for (const TensorInfo* input : inputs)
1704 {
1705 BOOST_ASSERT(input != nullptr);
1706 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1707 "Reference stack: input type not supported");
1708
1709 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1710 "Reference stack: input and output types mismatched.");
1711 }
1712
1713 return supported;
1714}
1715
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001716bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1717 const TensorInfo& output,
1718 const StridedSliceDescriptor& descriptor,
1719 Optional<std::string&> reasonIfUnsupported) const
1720{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001721 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001722 bool supported = true;
1723
1724 std::array<DataType,3> supportedTypes =
1725 {
1726 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001727 DataType::QAsymmU8,
1728 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001729 };
1730
1731 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1732 "Reference StridedSlice: input type not supported");
1733
1734 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1735 "Reference StridedSlice: output type not supported");
1736
1737 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1738 "Reference StridedSlice: input and output types are mismatched");
1739
1740 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001741}
1742
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001743bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1744 const TensorInfo& input1,
1745 const TensorInfo& output,
1746 Optional<std::string&> reasonIfUnsupported) const
1747{
Sadik Armagan2999a022019-04-09 14:20:12 +01001748 bool supported = true;
1749
Matthew Jackson9bff1442019-09-12 09:08:23 +01001750 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001751 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001752 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001753 DataType::QAsymmU8,
1754 DataType::QSymmS16
Sadik Armagan2999a022019-04-09 14:20:12 +01001755 };
1756
1757 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1758 "Reference subtraction: input 0 is not a supported type.");
1759
1760 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1761 "Reference subtraction: input 1 is not a supported type.");
1762
1763 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1764 "Reference subtraction: output is not a supported type.");
1765
1766 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1767 "Reference subtraction: input 0 and Input 1 types are mismatched");
1768
1769 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1770 "Reference subtraction: input and output types are mismatched");
1771
1772 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1773 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1774
1775 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001776}
1777
Matteo Martincighab9e5252019-06-13 17:27:46 +01001778bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1779 const TensorInfo& alpha,
1780 const TensorInfo& output,
1781 Optional<std::string&> reasonIfUnsupported) const
1782{
1783 bool supported = true;
1784
Matthew Jackson9bff1442019-09-12 09:08:23 +01001785 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001786 {
1787 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001788 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001789 DataType::QAsymmU8,
1790 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01001791 };
1792
1793 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1794 "PReLU: input is not a supported type.");
1795
1796 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1797 "PReLU: alpha is not a supported type.");
1798
1799 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1800 "PReLU: output is not a supported type.");
1801
1802 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1803 "PReLU: input, alpha and output types are mismatched");
1804
1805 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1806 "PReLU: shapes are not suitable for implicit broadcast");
1807
1808 return supported;
1809}
1810
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001811bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1812 const TensorInfo& output,
1813 const TransposeConvolution2dDescriptor& descriptor,
1814 const TensorInfo& weights,
1815 const Optional<TensorInfo>& biases,
1816 Optional<std::string&> reasonIfUnsupported) const
1817{
Derek Lamberti901ea112019-12-10 22:07:09 +00001818 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001819 bool supported = true;
1820
Matthew Jackson252df3a2019-09-11 09:19:18 +01001821 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001822 {
1823 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001824 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001825 DataType::QAsymmU8,
1826 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001827 };
1828
1829 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1830 "Reference TransposeConvolution2d: input is not a supported type.");
1831
1832 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1833 "Reference TransposeConvolution2d: output is not a supported type.");
1834
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001835 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1836 "Reference TransposeConvolution2d: input and output types mismatched.");
1837
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001838
1839 const DataType inputType = input.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +00001840 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001841 {
1842 std::array<DataType, 2> supportedWeightTypes =
1843 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00001844 DataType::QAsymmU8,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001845 DataType::QuantizedSymm8PerAxis
1846 };
1847
1848 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1849 "Reference TransposeConvolution2d: weights type not supported for "
1850 "quantized input.");
1851 }
1852 else
1853 {
1854 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1855 "Reference TransposeConvolution2d: weights is not a supported type.");
1856
1857 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1858 "Reference TransposeConvolution2d: input and weights types mismatched.");
1859 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001860
1861 if (biases.has_value())
1862 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001863 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001864 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001865 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001866 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001867 DataType::Signed32
1868 };
1869 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1870 "Reference TransposeConvolution2d: biases is not a supported type.");
1871 }
1872
1873 return supported;
1874}
1875
arovir011c7c81b2018-10-08 11:34:28 +01001876} // namespace armnn