blob: 05684dcbc091910db85d5b7431a34f50a9084b0f [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000011#include <armnn/BackendRegistry.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <armnnUtils/DataLayoutIndexed.hpp>
14
15#include <InternalTypes.hpp>
16#include <LayerSupportCommon.hpp>
17
Derek Lambertif674aa02019-08-01 15:56:25 +010018#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010021
Matteo Martincighe011d202019-11-28 11:35:47 +000022#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000024
Derek Lamberti50db4e82019-03-13 14:16:15 +000025#include <vector>
26#include <algorithm>
27#include <array>
28
telsoa014fcda012018-03-09 14:13:49 +000029using namespace boost;
30
31namespace armnn
32{
33
namespace
{

// Dispatches a layer-support query to the handler matching the tensor's data
// type: floatFuncPtr for Float32 and uint8FuncPtr for QuantisedAsymm8.
// The three FalseFunc entries fill the remaining data-type slots of
// IsSupportedForDataTypeGeneric (by the call pattern used elsewhere in this
// file these appear to be Float16, Signed32 and Boolean — confirm against
// LayerSupportCommon.hpp), so those types are always rejected here.
// reasonIfUnsupported is forwarded so the chosen handler can record why a
// configuration was refused; params are perfectly forwarded to the handler.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
55
namespace
{

// Builds the standard diagnostic reported when a tensor has the wrong number
// of dimensions for a reference layer.
//   expected   - dimension count the layer requires
//   actual     - dimension count the tensor actually has
//   layerStr   - human-readable layer name inserted into the message
//   tensorName - name of the offending tensor (e.g. "input", "output")
// Returns the composed message, e.g.
//   "Reference batchToSpaceNd: Expected 4 dimensions but got 2 dimensions
//    instead, for the 'output' tensor."
// The string parameters are taken by const reference: the function only reads
// them, and this also allows temporaries/literals at the call site.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000071
Sadik Armagan9199e582019-09-05 17:35:31 +010072bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
73 Optional<std::string&> reasonIfUnsupported) const
74{
75 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +010076 std::array<DataType,4> supportedTypes =
Sadik Armagan9199e582019-09-05 17:35:31 +010077 {
78 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +010079 DataType::Float16,
Sadik Armagan9199e582019-09-05 17:35:31 +010080 DataType::QuantisedAsymm8,
81 DataType::QuantisedSymm16
82 };
83
84 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
85 "Reference abs: input type not supported");
86
87 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
88 "Reference abs: output type not supported");
89
90 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
91 "Reference abs: input and output types not matching");
92
93 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
94 "Reference abs: input and output shapes have different number of total elements");
95
96 return supported;
97}
98
// Checks whether the reference backend can run an Activation layer with the
// given tensor types and activation function. On failure, appends the reason
// to reasonIfUnsupported (via CheckSupportRule) and returns false.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local ad-hoc rule: m_Res becomes true only when the requested
    // activation function is one of those enumerated below; anything else
    // (e.g. functions added to the enum later) falls into default and fails.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
162
163bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
164 const TensorInfo& input1,
165 const TensorInfo& output,
166 Optional<std::string&> reasonIfUnsupported) const
167{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000168 bool supported = true;
169
Matthew Jackson252df3a2019-09-11 09:19:18 +0100170 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000171 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100172 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100173 DataType::QuantisedAsymm8,
174 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000175 };
176
177 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
178 "Reference addition: input 0 is not a supported type.");
179
180 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
181 "Reference addition: input 1 is not a supported type.");
182
183 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
184 "Reference addition: output is not a supported type.");
185
186 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
187 "Reference addition: input 0 and Input 1 types are mismatched");
188
189 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
190 "Reference addition: input and output types are mismatched");
191
192 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
193 "Reference addition: shapes are not suitable for implicit broadcast.");
194
195 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100196}
197
Nikhil Raj68c2c902019-09-19 11:21:11 +0100198bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
199 const armnn::ArgMinMaxDescriptor &descriptor,
200 armnn::Optional<std::string &> reasonIfUnsupported) const
201{
202 ignore_unused(descriptor);
203
Francis Murtagh1939df52019-11-13 15:21:09 +0000204 std::array<DataType, 4> supportedTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100205 {
206 DataType::Float32,
207 DataType::QuantisedAsymm8,
Francis Murtagh1939df52019-11-13 15:21:09 +0000208 DataType::QuantisedSymm16,
209 DataType::Signed32
Nikhil Raj68c2c902019-09-19 11:21:11 +0100210 };
211
212 bool supported = true;
213
214 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
215 "Reference ArgMinMax: input is not a supported type.");
216 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
217 "Reference ArgMinMax: output type not supported");
218
219 return supported;
220}
221
arovir011c7c81b2018-10-08 11:34:28 +0100222bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
223 const TensorInfo& output,
224 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100225 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100226 const TensorInfo& beta,
227 const TensorInfo& gamma,
228 const BatchNormalizationDescriptor& descriptor,
229 Optional<std::string&> reasonIfUnsupported) const
230{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100231 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100232
Matthew Jackson9bff1442019-09-12 09:08:23 +0100233 std::array<DataType, 4> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100234 {
235 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100236 DataType::Float16,
Matteo Martincighf5507132019-06-04 10:59:47 +0100237 DataType::QuantisedAsymm8,
238 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100239 };
240
241 bool supported = true;
242
243 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
244 "Reference batch normalization: input is not a supported type.");
245
246 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
247 "Reference batch normalization: output is not a supported type.");
248
249 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
250 "Reference batch normalization: input and output types are mismatched");
251
252 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
253 "Reference batch normalization: mean is not a supported type.");
254
255 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
256 "Reference batch normalization: variance is not a supported type.");
257
258 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
259 "Reference batch normalization: beta is not a supported type.");
260
261 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
262 "Reference batch normalization: gamma is not a supported type.");
263
264 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100265}
266
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000267bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
268 const TensorInfo& output,
269 const BatchToSpaceNdDescriptor& descriptor,
270 Optional<std::string&> reasonIfUnsupported) const
271{
272 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100273
274 bool supported = true;
275
276 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
277 std::string inputTensorStr = "input";
278 std::string outputTensorStr = "output";
279
280 // Define supported types.
Matthew Jackson9bff1442019-09-12 09:08:23 +0100281 std::array<DataType,4> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100282 {
283 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100284 DataType::Float16,
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100285 DataType::QuantisedAsymm8,
286 DataType::QuantisedSymm16
287 };
288
289 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
290 "Reference BatchToSpaceNd: input type not supported.");
291
292 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
293 "Reference BatchToSpaceNd: output type not supported.");
294
295 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
296 "Reference BatchToSpaceNd: input and output types mismatched.");
297
298 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
299 reasonIfUnsupported,
300 CreateIncorrectDimensionsErrorMsg(4,
301 output.GetNumDimensions(),
302 batchToSpaceNdLayerStr,
303 outputTensorStr).data());
304
305 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
306 reasonIfUnsupported,
307 CreateIncorrectDimensionsErrorMsg(4,
308 input.GetNumDimensions(),
309 batchToSpaceNdLayerStr,
310 inputTensorStr).data());
311
312 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000313}
314
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100315bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
316 const TensorInfo& input1,
317 const TensorInfo& output,
318 const ComparisonDescriptor& descriptor,
319 Optional<std::string&> reasonIfUnsupported) const
320{
321 boost::ignore_unused(descriptor);
322
323 std::array<DataType, 4> supportedInputTypes =
324 {
325 DataType::Float32,
326 DataType::Float16,
327 DataType::QuantisedAsymm8,
328 DataType::QuantisedSymm16
329 };
330
331 bool supported = true;
332 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
333 "Reference comparison: input 0 is not a supported type");
334
335 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
336 "Reference comparison: input 0 and Input 1 types are mismatched");
337
338 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
339 "Reference comparison: output is not of type Boolean");
340
341 return supported;
342}
343
Jim Flynn906f9462019-05-10 13:55:21 +0100344bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
345 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100346 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100347 Optional<std::string&> reasonIfUnsupported) const
348{
Jim Flynne242f2d2019-05-22 14:24:13 +0100349 ignore_unused(descriptor);
350
351 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100352 std::array<DataType,4> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100353 {
354 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100355 DataType::Float16,
Jim Flynne242f2d2019-05-22 14:24:13 +0100356 DataType::QuantisedAsymm8,
357 DataType::QuantisedSymm16
358 };
359
360 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
361 "Reference concatenation: output type not supported");
362 for (const TensorInfo* input : inputs)
363 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100364 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100365 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
366 "Reference concatenation: input type not supported");
367
368 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
369 "Reference concatenation: input and output types mismatched.");
370 }
371
372 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100373}
374
arovir011c7c81b2018-10-08 11:34:28 +0100375bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
376 Optional<std::string&> reasonIfUnsupported) const
377{
Jim Flynne242f2d2019-05-22 14:24:13 +0100378 std::array<DataType,4> supportedTypes =
379 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100380 DataType::Float32,
381 DataType::Signed32,
382 DataType::QuantisedAsymm8,
383 DataType::QuantisedSymm16
384 };
385
386 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
387 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100388}
389
// Fp16 -> Fp32 conversion: the input must be Float16 and the output Float32.
// Both checks dispatch on the tensor's data type; the True*/False* helpers in
// the non-matching slots reject the configuration and record the reason.
// (By the helper names, the slot order appears to be F16, F32, U8, I32 and a
// final slot handled by FalseFuncU8 — confirm against LayerSupportCommon.hpp.)
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
409
// Fp32 -> Fp16 conversion: mirror image of IsConvertFp16ToFp32Supported.
// The input must be Float32 and the output Float16; all other data types are
// rejected via the False* helpers, which also record the failure reason.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
429
// Checks reference-backend support for a Convolution2d layer.
// Weights get special treatment for quantized inputs (per-axis symmetric
// weights are allowed); biases, when present, may be float or Signed32 — the
// rule below does not tie the bias type to the input type.
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        // Quantized inputs additionally accept per-axis quantized weights,
        // so the input/weights TypesAreEqual rule is deliberately skipped.
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    // Stride/padding/dilation settings do not influence type support.
    ignore_unused(descriptor);

    return supported;
}
494
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000495bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
496 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000497 Optional<std::string&> reasonIfUnsupported) const
498{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100499 bool supported = true;
500
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000501 std::array<DataType, 4> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100502 {
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000503 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100504 DataType::Float32,
505 DataType::QuantisedAsymm8,
506 DataType::QuantisedSymm16
507 };
508
509 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
510 "Reference debug: input type not supported");
511
512 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
513 "Reference debug: output type not supported");
514
515 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
516 "Reference debug: input and output types are mismatched");
517
518 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000519}
520
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100521bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
522 const TensorInfo& output,
523 const DepthToSpaceDescriptor& descriptor,
524 Optional<std::string&> reasonIfUnsupported) const
525{
526 ignore_unused(descriptor);
527 bool supported = true;
528
529 std::array<DataType,4> supportedTypes =
530 {
531 DataType::Float32,
532 DataType::Float16,
533 DataType::QuantisedAsymm8,
534 DataType::QuantisedSymm16
535 };
536
537 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
538 "Reference DepthToSpace: input type not supported");
539
540 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
541 "Reference DepthToSpace: output type not supported");
542
543 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
544 "Reference DepthToSpace: input and output types are mismatched");
545
546 return supported;
547}
548
arovir011c7c81b2018-10-08 11:34:28 +0100549bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
550 const TensorInfo& output,
551 const DepthwiseConvolution2dDescriptor& descriptor,
552 const TensorInfo& weights,
553 const Optional<TensorInfo>& biases,
554 Optional<std::string&> reasonIfUnsupported) const
555{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100556 bool supported = true;
557
558 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100559 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100560 {
561 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100562 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100563 DataType::QuantisedAsymm8,
564 DataType::QuantisedSymm16
565 };
566
567 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
568 "Reference DepthwiseConvolution2d: input is not a supported type.");
569
570 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
571 "Reference DepthwiseConvolution2d: output is not a supported type.");
572
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100573 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
574 "Reference DepthwiseConvolution2d: input and output types mismatched.");
575
Teresa Charlind8df0262019-11-11 12:28:15 +0000576 const DataType inputType = input.GetDataType();
577 if (inputType == DataType::QuantisedAsymm8)
578 {
579 std::array<DataType, 2> supportedWeightTypes =
580 {
581 DataType::QuantisedAsymm8,
582 DataType::QuantizedSymm8PerAxis
583 };
584
585 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
586 "Reference convolution2d: weights type not supported for quantized input.");
587 }
588 else
589 {
590 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
591 "Reference DepthwiseConvolution2d: weights is not a supported type.");
592
593 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
594 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
595 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100596
597 if (biases.has_value())
598 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100599 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100600 {
601 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100602 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100603 DataType::Signed32
604 };
605 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
606 "Reference DepthwiseConvolution2d: biases is not a supported type.");
607 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100608 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100609
610 return supported;
611
arovir011c7c81b2018-10-08 11:34:28 +0100612}
613
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000614bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
615 const TensorInfo& output,
616 Optional<std::string&> reasonIfUnsupported) const
617{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100618 bool supported = true;
619
620 std::array<DataType,2> supportedInputTypes = {
621 DataType::QuantisedAsymm8,
622 DataType::QuantisedSymm16
623 };
624
625 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
626 "Reference dequantize: input type not supported.");
627
Jan Eilersf7107932019-11-01 11:09:36 +0000628 std::array<DataType,2> supportedOutputTypes = {
629 DataType::Float32,
630 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100631 };
632
633 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
634 "Reference dequantize: output type not supported.");
635
636 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
637 "Reference dequantize: input and output shapes have different num total elements.");
638
639 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000640}
641
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000642bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
643 const armnn::TensorInfo& input1,
644 const armnn::DetectionPostProcessDescriptor& descriptor,
645 armnn::Optional<std::string&> reasonIfUnsupported) const
646{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100647 bool supported = true;
648
Mike Kelly4992c342019-08-14 11:33:11 +0100649 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100650 {
651 DataType::Float32,
652 DataType::QuantisedAsymm8,
653 DataType::QuantisedSymm16
654 };
655
656 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
657 "Reference DetectionPostProcess: input 0 is not a supported type.");
658
659 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
660 "Reference DetectionPostProcess: input 1 is not a supported type.");
661
662 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000663}
664
// Dilated depthwise convolution shares the type-support rules of the regular
// depthwise convolution, so this simply delegates; the dilation fields travel
// inside the same descriptor type.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
674
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100675bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100676 const TensorInfo& input1,
677 const TensorInfo& output,
678 Optional<std::string&> reasonIfUnsupported) const
679{
Sadik Armagan2999a022019-04-09 14:20:12 +0100680 bool supported = true;
681
Matthew Jackson9bff1442019-09-12 09:08:23 +0100682 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100683 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100684 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100685 DataType::QuantisedAsymm8,
686 DataType::QuantisedSymm16
687 };
688
689 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
690 "Reference division: input 0 is not a supported type.");
691
692 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
693 "Reference division: input 1 is not a supported type.");
694
695 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
696 "Reference division: output is not a supported type.");
697
698 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
699 "Reference division: input 0 and Input 1 types are mismatched");
700
701 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
702 "Reference division: input and output types are mismatched");
703
704 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
705 "Reference division: shapes are not suitable for implicit broadcast.");
706
707 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100708}
709
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000710bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
711 const TensorInfo& input1,
712 const TensorInfo& output,
713 Optional<std::string&> reasonIfUnsupported) const
714{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100715 return IsComparisonSupported(input0,
716 input1,
717 output,
718 ComparisonDescriptor(ComparisonOperation::Equal),
719 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000720}
721
arovir011c7c81b2018-10-08 11:34:28 +0100722bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
723 const FakeQuantizationDescriptor& descriptor,
724 Optional<std::string&> reasonIfUnsupported) const
725{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100726 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100727 bool supported = true;
728
729 std::array<DataType,1> supportedTypes =
730 {
731 DataType::Float32
732 };
733
734 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
735 "Reference fake quantization: input type not supported.");
736
737 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100738}
739
740bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
741 const TensorInfo& output,
742 Optional<std::string&> reasonIfUnsupported) const
743{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100744 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100745 bool supported = true;
746
Matthew Jackson9bff1442019-09-12 09:08:23 +0100747 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100748 {
James Conroyb40d7102019-06-04 12:32:09 +0100749 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100750 DataType::Float16,
James Conroyb40d7102019-06-04 12:32:09 +0100751 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100752 };
753
754 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
755 "Reference Floor: input type not supported.");
756
757 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
758 "Reference Floor: output type not supported.");
759
760 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100761}
762
763bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
764 const TensorInfo& output,
765 const TensorInfo& weights,
766 const TensorInfo& biases,
767 const FullyConnectedDescriptor& descriptor,
768 Optional<std::string&> reasonIfUnsupported) const
769{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100770 bool supported = true;
771
772 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100773 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100774 {
775 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100776 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100777 DataType::QuantisedAsymm8,
778 DataType::QuantisedSymm16
779 };
780
781 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
782 "Reference Fully Connected: input type not supported.");
783
784 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
785 "Reference Fully Connected: output type not supported.");
786
787 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
788 "Reference Fully Connected: input and output types mismatched.");
789
790 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
791 "Reference Fully Connected: weights type not supported.");
792
793 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
794 "Reference Fully Connected: input and weight types mismatched.");
795
796 if (descriptor.m_BiasEnabled)
797 {
798 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100799 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100800 supportedBiasTypes =
801 {
802 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100803 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100804 DataType::Signed32
805 };
806
807 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
808 "Reference Fully Connected: bias type not supported.");
809
810 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
811 "Reference Fully Connected: bias and weight types mismatch.");
812
813 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
814 "Reference Fully Connected: bias type inferred from weights is incompatible.");
815
816 }
817
818 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100819}
820
narpra014951d842019-01-18 16:53:53 +0000821bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
822 const armnn::TensorInfo& input1,
823 const armnn::TensorInfo& output,
824 armnn::Optional<std::string&> reasonIfUnsupported) const
825{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100826 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100827 std::array<DataType,4> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100828 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100829 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100830 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100831 DataType::QuantisedAsymm8,
832 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100833 };
834
835 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
836 "Reference Gather: input type not supported");
837
838 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
839 "Reference Gather: output type not supported");
840
841 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
842 "Reference Gather: indices (input1) type not supported");
843
844 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
845 "Reference Gather: input and output types not matching");
846
847 return supported;
narpra014951d842019-01-18 16:53:53 +0000848}
849
FrancisMurtagh878f0232018-12-19 10:56:15 +0000850bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
851 const TensorInfo& input1,
852 const TensorInfo& output,
853 Optional<std::string&> reasonIfUnsupported) const
854{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100855 return IsComparisonSupported(input0,
856 input1,
857 output,
858 ComparisonDescriptor(ComparisonOperation::Greater),
859 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +0000860}
861
arovir011c7c81b2018-10-08 11:34:28 +0100862bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
863 Optional<std::string&> reasonIfUnsupported) const
864{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100865 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100866}
867
Kevin May09ca49c2019-10-09 12:37:34 +0100868bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
869 const TensorInfo& output,
870 const InstanceNormalizationDescriptor& descriptor,
871 Optional<std::string&> reasonIfUnsupported) const
872{
873 ignore_unused(descriptor);
874 // Define supported types
875 std::array<DataType, 4> supportedTypes =
876 {
877 DataType::Float32,
878 DataType::Float16
879 };
880
881 bool supported = true;
882
883 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
884 "Reference Instance Normalization: input type not supported.");
885
886 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
887 "Reference Instance Normalization: output type not supported.");
888
889 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
890 "Reference Instance Normalization: input and output types mismatched.");
891
892 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
893 "Reference Instance Normalization: input and output shapes have different "
894 "num total elements.");
895
896 return supported;
897}
898
arovir011c7c81b2018-10-08 11:34:28 +0100899bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
900 const TensorInfo& output,
901 const L2NormalizationDescriptor& descriptor,
902 Optional<std::string&> reasonIfUnsupported) const
903{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100904 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100905 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100906 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100907 {
908 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100909 DataType::Float16,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100910 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100911 DataType::QuantisedSymm16
912 };
913
914 bool supported = true;
915
916 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
917 "Reference L2normalization: input type not supported.");
918
919 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
920 "Reference L2normalization: output type not supported.");
921
922 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
923 "Reference L2normalization: input and output types mismatched.");
924
925 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
926 "Reference L2normalization: input and output shapes have different "
927 "num total elements.");
928
929 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100930}
931
Aron Virginas-Tare662a942019-10-14 15:12:00 +0100932bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
933 const TensorInfo& output,
934 const LogSoftmaxDescriptor& descriptor,
935 Optional<std::string&> reasonIfUnsupported) const
936{
937 ignore_unused(descriptor);
938
939 std::array<DataType, 2> supportedTypes =
940 {
941 DataType::Float32,
942 DataType::Float16
943 };
944
945 bool supported = true;
946 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
947 "Reference LogSoftmax: input type not supported");
948
949 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
950 "Reference LogSoftmax: output type not supported");
951
952 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
953 "Reference LogSoftmax: input and output types do not match");
954
955 return supported;
956}
957
// Checks whether the reference backend can run an LSTM layer with the given
// configuration. The input data type must be Float32 or QuantisedSymm16, and
// every state tensor, output tensor and (enabled) weight/bias in paramsInfo
// must have the same data type as the input. Each failed rule appends an
// explanation to reasonIfUnsupported; all rules are evaluated (no early exit).
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    // Only these data types have reference LSTM workload implementations.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    // Mandatory weights/biases: present in every LSTM variant.
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // CIFG disabled => the input gate exists, so its weights/bias must be checked too.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        // Cell-to-input peephole weights only exist when both CIFG is off and peephole is on.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    // Peephole connections add cell-to-forget and cell-to-output weights.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    // Projection layer weights; the projection bias is optional (may be null).
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    // Layer-normalisation weights; the input-gate set only exists when CIFG is off.
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
1067
saoste012df12b32018-11-28 16:57:20 +00001068bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1069 const TensorInfo& input1,
1070 const TensorInfo& output,
1071 Optional<std::string&> reasonIfUnsupported) const
1072{
Sadik Armagan2999a022019-04-09 14:20:12 +01001073 bool supported = true;
1074
Matthew Jackson9bff1442019-09-12 09:08:23 +01001075 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001076 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001077 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001078 DataType::QuantisedAsymm8,
1079 DataType::QuantisedSymm16
1080 };
1081
1082 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1083 "Reference maximum: input 0 is not a supported type.");
1084
1085 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1086 "Reference maximum: input 1 is not a supported type.");
1087
1088 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1089 "Reference maximum: output is not a supported type.");
1090
1091 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1092 "Reference maximum: input 0 and Input 1 types are mismatched");
1093
1094 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1095 "Reference maximum: input and output types are mismatched");
1096
1097 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1098 "Reference maximum: shapes are not suitable for implicit broadcast.");
1099
1100 return supported;
saoste012df12b32018-11-28 16:57:20 +00001101}
1102
// Checks reference-backend support for a Mean (reduction) layer. Beyond type
// checks, it validates that the output rank matches what the descriptor
// implies: unchanged rank when KeepDims is set, rank 1 when reducing over all
// axes, otherwise input rank minus the number of reduced axes (clamped to 1).
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    // Strings used to build the dimension-mismatch error message.
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        // Reduced axes are kept as size-1 dimensions, so the rank is unchanged.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        // No axes given means reduce over everything: the output collapses to rank 1.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        // Each reduced axis removes one dimension from the input rank.
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            // All dimensions were reduced away; a scalar result is represented as rank 1.
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1163
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001164bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001165 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001166 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001167 Optional<std::string&> reasonIfUnsupported) const
1168{
Jim Flynne242f2d2019-05-22 14:24:13 +01001169 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001170}
1171
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001172bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1173 const TensorInfo &output,
1174 Optional<std::string &> reasonIfUnsupported) const
1175{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001176 bool supported = true;
1177
1178 std::array<DataType,5> supportedTypes =
1179 {
1180 DataType::Float32,
1181 DataType::Float16,
1182 DataType::QuantisedAsymm8,
1183 DataType::QuantisedSymm16,
1184 DataType::Boolean
1185 };
1186
1187 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1188 "Reference MemCopy: input type not supported");
1189
1190 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1191 "Reference MemCopy: output type not supported");
1192
1193 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1194 "Reference MemCopy: input and output types are mismatched");
1195
1196 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001197}
1198
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001199bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1200 const TensorInfo& input1,
1201 const TensorInfo& output,
1202 Optional<std::string&> reasonIfUnsupported) const
1203{
Sadik Armagan2999a022019-04-09 14:20:12 +01001204 bool supported = true;
1205
Matthew Jackson9bff1442019-09-12 09:08:23 +01001206 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001207 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001208 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001209 DataType::QuantisedAsymm8,
1210 DataType::QuantisedSymm16
1211 };
1212
1213 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1214 "Reference minimum: input 0 is not a supported type.");
1215
1216 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1217 "Reference minimum: input 1 is not a supported type.");
1218
1219 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1220 "Reference minimum: output is not a supported type.");
1221
1222 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1223 "Reference minimum: input 0 and Input 1 types are mismatched");
1224
1225 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1226 "Reference minimum: input and output types are mismatched");
1227
1228 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1229 "Reference minimum: shapes are not suitable for implicit broadcast.");
1230
1231 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001232}
1233
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001234bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1235 const TensorInfo& input1,
1236 const TensorInfo& output,
1237 Optional<std::string&> reasonIfUnsupported) const
1238{
Sadik Armagan2999a022019-04-09 14:20:12 +01001239 bool supported = true;
1240
Matthew Jackson252df3a2019-09-11 09:19:18 +01001241 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001242 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001243 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001244 DataType::QuantisedAsymm8,
1245 DataType::QuantisedSymm16
1246 };
1247
1248 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1249 "Reference multiplication: input 0 is not a supported type.");
1250
1251 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1252 "Reference multiplication: input 1 is not a supported type.");
1253
1254 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1255 "Reference multiplication: output is not a supported type.");
1256
1257 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1258 "Reference multiplication: input 0 and Input 1 types are mismatched");
1259
1260 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1261 "Reference multiplication: input and output types are mismatched");
1262
1263 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1264 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1265
1266 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001267}
1268
1269bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1270 const TensorInfo& output,
1271 const NormalizationDescriptor& descriptor,
1272 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001273{
Nina Drozd661dfa72018-10-02 11:14:17 +01001274 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001275
1276 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001277 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001278 {
1279 DataType::Float16,
1280 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001281 DataType::QuantisedAsymm8,
1282 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001283 };
1284
1285 bool supported = true;
1286
1287 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1288 "Reference normalization: input type not supported.");
1289
1290 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1291 "Reference normalization: output type not supported.");
1292
1293 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1294 "Reference normalization: input and output shapes have different "
1295 "num total elements.");
1296
1297 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001298}
1299
1300bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1301 Optional<std::string&> reasonIfUnsupported) const
1302{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001303 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001304}
1305
1306bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1307 const TensorInfo& output,
1308 const PadDescriptor& descriptor,
1309 Optional<std::string&> reasonIfUnsupported) const
1310{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001311 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001312 bool supported = true;
1313
1314 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001315 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001316 {
1317 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001318 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001319 DataType::QuantisedAsymm8,
1320 DataType::QuantisedSymm16
1321 };
1322
1323 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1324 "Reference pad: input is not a supported type.");
1325
1326 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1327 "Reference pad: output is not a supported type.");
1328
1329 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1330 "Reference pad: input and output types are mismatched.");
1331
1332 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001333}
1334
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001335bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1336 const TensorInfo& output,
1337 const PermuteDescriptor& descriptor,
1338 Optional<std::string&> reasonIfUnsupported) const
1339{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001340 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001341 bool supported = true;
1342
1343 // Define supported output and inputs types.
1344 std::array<DataType,3> supportedTypes =
1345 {
1346 DataType::Float32,
1347 DataType::QuantisedAsymm8,
1348 DataType::QuantisedSymm16
1349 };
1350
1351 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1352 "Reference permute: input is not a supported type.");
1353
1354 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1355 "Reference permute: output is not a supported type.");
1356
1357 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1358 "Reference permute: input and output types are mismatched.");
1359
1360 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001361}
1362
1363bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1364 const TensorInfo& output,
1365 const Pooling2dDescriptor& descriptor,
1366 Optional<std::string&> reasonIfUnsupported) const
1367{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001368 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001369 bool supported = true;
1370
1371 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001372 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001373 {
1374 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001375 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001376 DataType::QuantisedAsymm8,
1377 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001378 };
1379
1380 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1381 "Reference poolind2d: input is not a supported type.");
1382
1383 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1384 "Reference poolind2d: output is not a supported type.");
1385
1386 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1387 "Reference poolind2d: input and output types are mismatched.");
1388
1389 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001390}
1391
// Reports whether the reference backend supports a Quantize layer.
// Quantize converts a float tensor into a quantized representation, so the
// input must be Float32 and the output one of the quantized types; shapes
// must carry the same total number of elements.
bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported input types.
    std::array<DataType,1> supportedInputTypes = {
        DataType::Float32,
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference quantize: input type not supported.");

    // Define supported output types.
    std::array<DataType,2> supportedOutputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference quantize: output type not supported.");

    // Quantize is elementwise: input and output must have the same element count.
    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference quantize: input and output shapes have different num total elements.");

    return supported;
}
1419
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001420bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001421 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001422 Optional<std::string&> reasonIfUnsupported) const
1423{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001424 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001425 // Define supported output types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001426 std::array<DataType,5> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001427 {
1428 DataType::Float32,
1429 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001430 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001431 DataType::QuantisedAsymm8,
1432 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001433 };
1434 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1435 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001436}
1437
1438bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001439 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001440 Optional<std::string&> reasonIfUnsupported) const
1441{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001442 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001443 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001444 {
1445 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001446 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001447 DataType::QuantisedAsymm8,
1448 DataType::QuantisedSymm16
1449 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001450
1451 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1452 "Reference ResizeBilinear: input type not supported");
1453
1454 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1455 "Reference ResizeBilinear: output type not supported");
1456
1457 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1458 "Reference ResizeBilinear: input and output types not matching");
1459
1460 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001461}
1462
Teresa Charlin970f43b2019-07-01 13:51:07 +01001463bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1464 const TensorInfo& output,
1465 const ResizeDescriptor& descriptor,
1466 Optional<std::string&> reasonIfUnsupported) const
1467{
1468 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001469 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001470 {
1471 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001472 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001473 DataType::QuantisedAsymm8,
1474 DataType::QuantisedSymm16
1475 };
1476
1477 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1478 "Reference Resize: input type not supported");
1479
1480 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1481 "Reference Resize: output type not supported");
1482
1483 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1484 "Reference Resize: input and output types not matching");
1485
1486 return supported;
1487}
1488
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001489bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1490 const TensorInfo& output,
1491 Optional<std::string&> reasonIfUnsupported) const
1492{
nikraj010421e7f2019-06-14 09:40:34 +01001493 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001494 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001495 {
1496 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001497 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001498 DataType::QuantisedAsymm8,
1499 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001500 };
1501
1502 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1503 "Reference rsqrt: input type not supported");
1504
1505 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1506 "Reference rsqrt: output type not supported");
1507
1508 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1509 "Reference rsqrt: input and output types not matching");
1510
1511 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1512 "Reference Rsqrt: input and output shapes have different number of total elements");
1513
1514 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001515}
1516
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001517bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1518 const TensorInfo& output,
1519 const SliceDescriptor& descriptor,
1520 Optional<std::string&> reasonIfUnsupported) const
1521{
1522 ignore_unused(descriptor);
1523 bool supported = true;
1524
1525 std::array<DataType, 3> supportedTypes =
1526 {
1527 DataType::Float32,
1528 DataType::QuantisedAsymm8,
1529 DataType::QuantisedSymm16
1530 };
1531
1532 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1533 "Reference Slice: input type not supported");
1534
1535 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1536 "Reference Slice: output type not supported");
1537
1538 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1539 "Reference Slice: input and output types are mismatched");
1540
1541 return supported;
1542}
1543
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001544bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1545 const TensorInfo& output,
1546 const SoftmaxDescriptor& descriptor,
1547 Optional<std::string&> reasonIfUnsupported) const
1548{
1549 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001550 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001551 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001552 {
1553 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001554 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001555 DataType::QuantisedAsymm8,
1556 DataType::QuantisedSymm16
1557 };
1558
1559 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001560 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001561
1562 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001563 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001564
1565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001566 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001567
1568 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001569}
1570
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001571bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1572 const TensorInfo& output,
1573 const SpaceToBatchNdDescriptor& descriptor,
1574 Optional<std::string&> reasonIfUnsupported) const
1575{
1576 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001577 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001578 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001579 {
1580 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001581 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001582 DataType::QuantisedAsymm8,
1583 DataType::QuantisedSymm16
1584 };
1585
1586 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1587 "Reference SpaceToBatchNd: input type not supported");
1588
1589 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1590 "Reference SpaceToBatchNd: output type not supported");
1591
1592 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1593 "Reference SpaceToBatchNd: input and output types are mismatched");
1594
1595 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001596}
1597
Keith Davisa57eccb2019-06-14 17:33:22 +01001598bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001599 const TensorInfo& output,
1600 const SpaceToDepthDescriptor& descriptor,
1601 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001602{
1603
1604 ignore_unused(descriptor);
1605 bool supported = true;
1606
Matthew Jackson9bff1442019-09-12 09:08:23 +01001607 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001608 {
1609 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001610 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001611 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001612 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001613 };
1614
1615 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1616 "Reference SpaceToDepth: input type not supported");
1617
1618 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1619 "Reference SpaceToDepth: output type not supported");
1620
1621 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1622 "Reference SpaceToDepth: input and output types are mismatched");
1623
1624 return supported;
1625}
1626
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001627bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1628 const ViewsDescriptor& descriptor,
1629 Optional<std::string&> reasonIfUnsupported) const
1630{
1631 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001632 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001633 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001634 {
1635 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001636 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001637 DataType::QuantisedAsymm8,
1638 DataType::QuantisedSymm16
1639 };
1640
1641 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1642 "Reference splitter: input type not supported");
1643
1644 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001645}
1646
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001647bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1648 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1649 const ViewsDescriptor& descriptor,
1650 Optional<std::string&> reasonIfUnsupported) const
1651{
1652 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001653 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001654 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001655 {
1656 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001657 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001658 DataType::QuantisedAsymm8,
1659 DataType::QuantisedSymm16
1660 };
1661
1662 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1663 "Reference splitter: output type not supported");
1664 for (const TensorInfo output : outputs)
1665 {
1666 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1667 "Reference splitter: input type not supported");
1668
1669 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1670 "Reference splitter: input and output types mismatched.");
1671 }
1672
1673 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001674}
1675
Matthew Jackson81e601c2019-07-11 12:07:09 +01001676bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1677 const TensorInfo& output,
1678 const StackDescriptor& descriptor,
1679 Optional<std::string&> reasonIfUnsupported) const
1680{
1681 ignore_unused(descriptor);
1682
1683 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001684 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001685 {
1686 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001687 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001688 DataType::QuantisedAsymm8,
1689 DataType::QuantisedSymm16
1690 };
1691
1692 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1693 "Reference stack: output type not supported");
1694 for (const TensorInfo* input : inputs)
1695 {
1696 BOOST_ASSERT(input != nullptr);
1697 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1698 "Reference stack: input type not supported");
1699
1700 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1701 "Reference stack: input and output types mismatched.");
1702 }
1703
1704 return supported;
1705}
1706
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001707bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1708 const TensorInfo& output,
1709 const StridedSliceDescriptor& descriptor,
1710 Optional<std::string&> reasonIfUnsupported) const
1711{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001712 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001713 bool supported = true;
1714
1715 std::array<DataType,3> supportedTypes =
1716 {
1717 DataType::Float32,
1718 DataType::QuantisedAsymm8,
1719 DataType::QuantisedSymm16
1720 };
1721
1722 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1723 "Reference StridedSlice: input type not supported");
1724
1725 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1726 "Reference StridedSlice: output type not supported");
1727
1728 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1729 "Reference StridedSlice: input and output types are mismatched");
1730
1731 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001732}
1733
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001734bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1735 const TensorInfo& input1,
1736 const TensorInfo& output,
1737 Optional<std::string&> reasonIfUnsupported) const
1738{
Sadik Armagan2999a022019-04-09 14:20:12 +01001739 bool supported = true;
1740
Matthew Jackson9bff1442019-09-12 09:08:23 +01001741 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001742 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001743 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001744 DataType::QuantisedAsymm8,
1745 DataType::QuantisedSymm16
1746 };
1747
1748 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1749 "Reference subtraction: input 0 is not a supported type.");
1750
1751 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1752 "Reference subtraction: input 1 is not a supported type.");
1753
1754 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1755 "Reference subtraction: output is not a supported type.");
1756
1757 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1758 "Reference subtraction: input 0 and Input 1 types are mismatched");
1759
1760 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1761 "Reference subtraction: input and output types are mismatched");
1762
1763 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1764 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1765
1766 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001767}
1768
Matteo Martincighab9e5252019-06-13 17:27:46 +01001769bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1770 const TensorInfo& alpha,
1771 const TensorInfo& output,
1772 Optional<std::string&> reasonIfUnsupported) const
1773{
1774 bool supported = true;
1775
Matthew Jackson9bff1442019-09-12 09:08:23 +01001776 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001777 {
1778 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001779 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001780 DataType::QuantisedAsymm8,
1781 DataType::QuantisedSymm16
1782 };
1783
1784 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1785 "PReLU: input is not a supported type.");
1786
1787 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1788 "PReLU: alpha is not a supported type.");
1789
1790 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1791 "PReLU: output is not a supported type.");
1792
1793 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1794 "PReLU: input, alpha and output types are mismatched");
1795
1796 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1797 "PReLU: shapes are not suitable for implicit broadcast");
1798
1799 return supported;
1800}
1801
// Reports whether the reference backend supports a TransposeConvolution2d
// layer with the given tensor types. Input/output must share a supported
// type; weights are validated against the input type (with special handling
// for quantized inputs), and optional biases against a separate type list.
// NOTE(review): 'descriptor' is not inspected here and has no ignore_unused
// — presumably stride/padding never affect type support; confirm no
// unused-parameter warning is raised in this build.
bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
                                                        const TensorInfo& output,
                                                        const TransposeConvolution2dDescriptor& descriptor,
                                                        const TensorInfo& weights,
                                                        const Optional<TensorInfo>& biases,
                                                        Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Types accepted for the input, output and (non-quantized case) weights.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input and output types mismatched.");


    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        // Quantized input: weights may be per-tensor asymmetric or
        // per-axis symmetric quantized, so they need not equal the input type.
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        // Non-quantized input: weights must be a supported type AND match the input.
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: input and weights types mismatched.");
    }

    // Biases are optional; when present they use their own (accumulator) type list.
    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: biases is not a supported type.");
    }

    return supported;
}
1865
arovir011c7c81b2018-10-08 11:34:28 +01001866} // namespace armnn