blob: 56ca437b211b2c5c59328ad4439fddf29ba2d566 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01009#include <DataLayoutIndexed.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <InternalTypes.hpp>
11#include <LayerSupportCommon.hpp>
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +010012
telsoa014fcda012018-03-09 14:13:49 +000013#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000014#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
David Beck111b5d92018-11-12 14:59:37 +000016#include <backendsCommon/BackendRegistry.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010017#include <backendsCommon/LayerSupportRules.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010018#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010019
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
Derek Lamberti50db4e82019-03-13 14:16:15 +000022#include <vector>
23#include <algorithm>
24#include <array>
25
telsoa014fcda012018-03-09 14:13:49 +000026using namespace boost;
27
28namespace armnn
29{
30
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010031namespace
32{
33
// Reference-backend adapter over the generic data-type dispatcher:
// the caller supplies predicates for Float32 and QuantisedAsymm8 only,
// while the remaining data-type slots are rejected via FalseFunc.
// NOTE(review): the slot order passed through here (F16, F32, U8, then two
// further FalseFunc slots, presumably Signed32 and Boolean) mirrors the
// IsSupportedForDataTypeGeneric overload declared in LayerSupportCommon.hpp
// -- confirm against that header if the overload set changes.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: never supported here
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}
50
51} // anonymous namespace
52
James Conroy4d1ff582019-06-10 17:06:39 +010053namespace
54{
55
// Builds the standard "wrong number of dimensions" diagnostic used by the
// reference layer-support checks.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name (e.g. "batchToSpaceNd").
// @param tensorName Role of the offending tensor (e.g. "input", "output").
// @return           The formatted error message.
//
// Fix: the string parameters are now taken by const reference -- the function
// never mutates them, and the previous non-const references needlessly
// rejected const strings and temporaries at the call site.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
66
67} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000068
// Checks whether the reference backend supports an Activation layer.
// Requires: input/output types in the supported set, matching types,
// matching rank, and an activation function the reference implementation
// can evaluate. Each failed rule appends a reason to reasonIfUnsupported.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: m_Res is true iff the descriptor's activation function is
    // one the reference backend implements.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    // Unknown/unimplemented activation function.
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
131
132bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
133 const TensorInfo& input1,
134 const TensorInfo& output,
135 Optional<std::string&> reasonIfUnsupported) const
136{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000137 bool supported = true;
138
Sadik Armagan2999a022019-04-09 14:20:12 +0100139 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000140 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100141 DataType::QuantisedAsymm8,
142 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000143 };
144
145 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
146 "Reference addition: input 0 is not a supported type.");
147
148 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
149 "Reference addition: input 1 is not a supported type.");
150
151 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
152 "Reference addition: output is not a supported type.");
153
154 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
155 "Reference addition: input 0 and Input 1 types are mismatched");
156
157 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
158 "Reference addition: input and output types are mismatched");
159
160 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
161 "Reference addition: shapes are not suitable for implicit broadcast.");
162
163 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100164}
165
// Checks whether the reference backend supports a BatchNormalization layer.
// Input/output must share a supported type; each statistics tensor (mean,
// variance, beta, gamma) must itself be of a supported type. The descriptor
// (epsilon, data layout) imposes no restrictions here.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
209
// Checks whether the reference backend supports a BatchToSpaceNd layer.
// Input and output must be 4D tensors of matching, supported types.
// The descriptor (block shape / crops) imposes no restrictions here.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    // Note: the temporary std::string returned by
    // CreateIncorrectDimensionsErrorMsg lives until the end of the full
    // expression, so passing .data() into CheckSupportRule here is safe as
    // long as CheckSupportRule copies/uses the message immediately.
    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}
256
Jim Flynn906f9462019-05-10 13:55:21 +0100257bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
258 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100259 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100260 Optional<std::string&> reasonIfUnsupported) const
261{
Jim Flynne242f2d2019-05-22 14:24:13 +0100262 ignore_unused(descriptor);
263
264 bool supported = true;
265 std::array<DataType,3> supportedTypes =
266 {
267 DataType::Float32,
268 DataType::QuantisedAsymm8,
269 DataType::QuantisedSymm16
270 };
271
272 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
273 "Reference concatenation: output type not supported");
274 for (const TensorInfo* input : inputs)
275 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100276 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100277 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
278 "Reference concatenation: input type not supported");
279
280 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
281 "Reference concatenation: input and output types mismatched.");
282 }
283
284 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100285}
286
arovir011c7c81b2018-10-08 11:34:28 +0100287bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
288 Optional<std::string&> reasonIfUnsupported) const
289{
Jim Flynne242f2d2019-05-22 14:24:13 +0100290 std::array<DataType,4> supportedTypes =
291 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100292 DataType::Float32,
293 DataType::Signed32,
294 DataType::QuantisedAsymm8,
295 DataType::QuantisedSymm16
296 };
297
298 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
299 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100300}
301
// Checks whether the reference backend supports a ConvertFp16ToFp32 layer.
// Supported only when the input is Float16 and the output is Float32; every
// other combination is rejected via the False* predicates (which also write
// the reason). NOTE(review): the predicate slot order (F16, F32, U8, I32,
// Boolean) is assumed from the generic overload -- confirm against
// LayerSupportCommon.hpp.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
321
// Checks whether the reference backend supports a ConvertFp32ToFp16 layer.
// Mirror image of IsConvertFp16ToFp32Supported: supported only when the
// input is Float32 and the output is Float16.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
341
342bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
343 const TensorInfo& output,
344 const Convolution2dDescriptor& descriptor,
345 const TensorInfo& weights,
346 const Optional<TensorInfo>& biases,
347 Optional<std::string&> reasonIfUnsupported) const
348{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100349 bool supported = true;
350
351 // Define supported types.
352 std::array<DataType,3> supportedTypes = {
353 DataType::Float32,
354 DataType::QuantisedAsymm8,
355 DataType::QuantisedSymm16
356 };
357
358 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100359 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100360
361 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100362 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100363
364 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100365 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100366
367 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100368 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100369
370 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100371 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100372
373 if (biases.has_value())
374 {
375 std::array<DataType,3> biasesSupportedTypes = {
376 DataType::Float32,
377 DataType::Signed32
378 };
379 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100380 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100381 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100382 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100383
384 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100385}
386
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000387bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
388 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000389 Optional<std::string&> reasonIfUnsupported) const
390{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100391 bool supported = true;
392
393 std::array<DataType,3> supportedTypes =
394 {
395 DataType::Float32,
396 DataType::QuantisedAsymm8,
397 DataType::QuantisedSymm16
398 };
399
400 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
401 "Reference debug: input type not supported");
402
403 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
404 "Reference debug: output type not supported");
405
406 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
407 "Reference debug: input and output types are mismatched");
408
409 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000410}
411
// Checks whether the reference backend supports a DepthwiseConvolution2d
// layer. Input, output and weights must use one of the supported data types
// and agree with each other; an optional bias must be Float32 or Signed32.
// The descriptor imposes no restrictions here.
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        // Bias has its own, narrower type set.
        std::array<DataType,2> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;

}
459
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000460bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
461 const TensorInfo& output,
462 Optional<std::string&> reasonIfUnsupported) const
463{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100464 bool supported = true;
465
466 std::array<DataType,2> supportedInputTypes = {
467 DataType::QuantisedAsymm8,
468 DataType::QuantisedSymm16
469 };
470
471 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
472 "Reference dequantize: input type not supported.");
473
474 std::array<DataType,2> supportedOutputTypes = {
475 DataType::Float32,
476 };
477
478 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
479 "Reference dequantize: output type not supported.");
480
481 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
482 "Reference dequantize: input and output shapes have different num total elements.");
483
484 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000485}
486
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000487bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
488 const armnn::TensorInfo& input1,
489 const armnn::DetectionPostProcessDescriptor& descriptor,
490 armnn::Optional<std::string&> reasonIfUnsupported) const
491{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100492 bool supported = true;
493
494 std::vector<DataType> supportedInputTypes =
495 {
496 DataType::Float32,
497 DataType::QuantisedAsymm8,
498 DataType::QuantisedSymm16
499 };
500
501 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
502 "Reference DetectionPostProcess: input 0 is not a supported type.");
503
504 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
505 "Reference DetectionPostProcess: input 1 is not a supported type.");
506
507 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000508}
509
// Checks whether the reference backend supports a dilated
// DepthwiseConvolution2d. The reference implementation applies the same
// type rules regardless of dilation, so this simply forwards to
// IsDepthwiseConvolutionSupported (the descriptor carries the dilation
// parameters).
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
519
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100520bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100521 const TensorInfo& input1,
522 const TensorInfo& output,
523 Optional<std::string&> reasonIfUnsupported) const
524{
Sadik Armagan2999a022019-04-09 14:20:12 +0100525 bool supported = true;
526
527 std::array<DataType,3> supportedTypes = {
528 DataType::Float32,
529 DataType::QuantisedAsymm8,
530 DataType::QuantisedSymm16
531 };
532
533 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
534 "Reference division: input 0 is not a supported type.");
535
536 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
537 "Reference division: input 1 is not a supported type.");
538
539 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
540 "Reference division: output is not a supported type.");
541
542 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
543 "Reference division: input 0 and Input 1 types are mismatched");
544
545 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
546 "Reference division: input and output types are mismatched");
547
548 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
549 "Reference division: shapes are not suitable for implicit broadcast.");
550
551 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100552}
553
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000554bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
555 const TensorInfo& input1,
556 const TensorInfo& output,
557 Optional<std::string&> reasonIfUnsupported) const
558{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100559 bool supported = true;
560
561 std::array<DataType,3> supportedTypes =
562 {
563 DataType::Float32,
564 DataType::QuantisedAsymm8,
565 DataType::QuantisedSymm16
566 };
567
568 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
569 "Reference equal: input 0 is not a supported type.");
570
571 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
572 "Reference equal: input 1 is not a supported type.");
573
574 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
575 "Reference equal: input 0 and Input 1 types are mismatched");
576
577 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
578 "Reference equal: shapes are not suitable for implicit broadcast.");
579
580 return supported;
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000581}
582
arovir011c7c81b2018-10-08 11:34:28 +0100583bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
584 const FakeQuantizationDescriptor& descriptor,
585 Optional<std::string&> reasonIfUnsupported) const
586{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100587 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100588 bool supported = true;
589
590 std::array<DataType,1> supportedTypes =
591 {
592 DataType::Float32
593 };
594
595 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
596 "Reference fake quantization: input type not supported.");
597
598 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100599}
600
601bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
602 const TensorInfo& output,
603 Optional<std::string&> reasonIfUnsupported) const
604{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100605 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100606 bool supported = true;
607
James Conroyb40d7102019-06-04 12:32:09 +0100608 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100609 {
James Conroyb40d7102019-06-04 12:32:09 +0100610 DataType::Float32,
611 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100612 };
613
614 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
615 "Reference Floor: input type not supported.");
616
617 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
618 "Reference Floor: output type not supported.");
619
620 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100621}
622
// Checks whether the reference backend supports a FullyConnected layer.
// Input, output and weights must use one of the supported data types and
// agree with each other. When the descriptor enables a bias, the bias type
// must be Float32 or Signed32, must match the weights, and must be the bias
// type inferred from the weight type.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 2>
        supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}
678
narpra014951d842019-01-18 16:53:53 +0000679bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
680 const armnn::TensorInfo& input1,
681 const armnn::TensorInfo& output,
682 armnn::Optional<std::string&> reasonIfUnsupported) const
683{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100684 bool supported = true;
685 std::array<DataType,3> supportedTypes =
686 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100687 DataType::Float32,
688 DataType::QuantisedAsymm8,
689 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100690 };
691
692 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
693 "Reference Gather: input type not supported");
694
695 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
696 "Reference Gather: output type not supported");
697
698 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
699 "Reference Gather: indices (input1) type not supported");
700
701 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
702 "Reference Gather: input and output types not matching");
703
704 return supported;
narpra014951d842019-01-18 16:53:53 +0000705}
706
FrancisMurtagh878f0232018-12-19 10:56:15 +0000707bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
708 const TensorInfo& input1,
709 const TensorInfo& output,
710 Optional<std::string&> reasonIfUnsupported) const
711{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100712 bool supported = true;
713
714 std::array<DataType,3> supportedTypes =
715 {
716 DataType::Float32,
717 DataType::QuantisedAsymm8,
718 DataType::QuantisedSymm16
719 };
720
721 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
722 "Reference greater: input 0 is not a supported type.");
723
724 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
725 "Reference greater: input 1 is not a supported type.");
726
727 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
728 "Reference greater: input 0 and Input 1 types are mismatched");
729
730 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
731 "Reference greater: shapes are not suitable for implicit broadcast.");
732
733 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000734}
735
arovir011c7c81b2018-10-08 11:34:28 +0100736bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
737 Optional<std::string&> reasonIfUnsupported) const
738{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100739 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100740}
741
// Checks whether the reference backend supports an L2Normalization layer.
// Input and output must share one of the supported data types and contain
// the same total number of elements. The descriptor (epsilon, data layout)
// imposes no restrictions here.
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}
773
// Checks whether the reference backend can execute an LSTM layer: all tensors
// (inputs, outputs, states and every weight/bias in paramsInfo) must share one
// of the supported data types, with the set of required parameters depending
// on the descriptor's CIFG / peephole / projection / layer-norm flags.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // NOTE(review): both descriptor and paramsInfo ARE read below, so these
    // ignore_unused calls are redundant no-ops (kept byte-identical here).
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    // Reference LSTM is implemented for float and 16-bit symmetric quantised data.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs: the input sets the reference type and every
    // state/output tensor must match it exactly.
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters: the gate weights/biases that are always present.
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // Input-gate parameters exist only when CIFG (coupled input-forget gate) is off.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        // Cell-to-input peephole weights exist only with peephole AND no CIFG.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    // Peephole connections add cell-to-gate weights for the forget/output gates.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    // Projection adds weights and an optional bias on the output path.
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    // Layer normalisation adds per-gate normalisation weights; the input-gate
    // weights are again only present when CIFG is disabled.
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
883
saoste012df12b32018-11-28 16:57:20 +0000884bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
885 const TensorInfo& input1,
886 const TensorInfo& output,
887 Optional<std::string&> reasonIfUnsupported) const
888{
Sadik Armagan2999a022019-04-09 14:20:12 +0100889 bool supported = true;
890
Sadik Armagan68db21f2019-08-09 16:44:10 +0100891 std::array<DataType,3> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100892 DataType::Float32,
893 DataType::QuantisedAsymm8,
894 DataType::QuantisedSymm16
895 };
896
897 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
898 "Reference maximum: input 0 is not a supported type.");
899
900 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
901 "Reference maximum: input 1 is not a supported type.");
902
903 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
904 "Reference maximum: output is not a supported type.");
905
906 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
907 "Reference maximum: input 0 and Input 1 types are mismatched");
908
909 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
910 "Reference maximum: input and output types are mismatched");
911
912 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
913 "Reference maximum: shapes are not suitable for implicit broadcast.");
914
915 return supported;
saoste012df12b32018-11-28 16:57:20 +0000916}
917
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100918bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
919 const TensorInfo& output,
920 const MeanDescriptor& descriptor,
921 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100922{
James Conroy4d1ff582019-06-10 17:06:39 +0100923 bool supported = true;
924 std::string meanLayerStr = "Mean";
925 std::string outputTensorStr = "output";
926
James Conroyb80775f2019-06-11 11:25:30 +0100927 std::array<DataType,3> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +0100928 {
929 DataType::Float32,
James Conroyb80775f2019-06-11 11:25:30 +0100930 DataType::QuantisedAsymm8,
931 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +0100932 };
933
934 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
935 "Reference Mean: input type not supported.");
936
937 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
938 "Reference Mean: input and output types are mismatched");
939
940 if (descriptor.m_KeepDims)
941 {
942 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
943 reasonIfUnsupported,
944 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
945 output.GetNumDimensions(),
946 meanLayerStr, outputTensorStr).data());
947 }
948 else if (descriptor.m_Axis.empty())
949 {
950 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
951 reasonIfUnsupported,
952 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
953 meanLayerStr, outputTensorStr).data());
954 }
955 else
956 {
957 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
958
959 if (outputDim > 0)
960 {
961 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
962 reasonIfUnsupported,
963 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
964 meanLayerStr, outputTensorStr).data());
965 }
966 else
967 {
968 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
969 reasonIfUnsupported,
970 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
971 meanLayerStr, outputTensorStr).data());
972 }
973 }
974
975 return supported;
narpra0132b90462018-09-13 11:07:48 +0100976}
977
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100978bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000979 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100980 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100981 Optional<std::string&> reasonIfUnsupported) const
982{
Jim Flynne242f2d2019-05-22 14:24:13 +0100983 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100984}
985
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000986bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
987 const TensorInfo &output,
988 Optional<std::string &> reasonIfUnsupported) const
989{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100990 bool supported = true;
991
992 std::array<DataType,5> supportedTypes =
993 {
994 DataType::Float32,
995 DataType::Float16,
996 DataType::QuantisedAsymm8,
997 DataType::QuantisedSymm16,
998 DataType::Boolean
999 };
1000
1001 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1002 "Reference MemCopy: input type not supported");
1003
1004 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1005 "Reference MemCopy: output type not supported");
1006
1007 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1008 "Reference MemCopy: input and output types are mismatched");
1009
1010 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001011}
1012
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001013bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1014 const TensorInfo& input1,
1015 const TensorInfo& output,
1016 Optional<std::string&> reasonIfUnsupported) const
1017{
Sadik Armagan2999a022019-04-09 14:20:12 +01001018 bool supported = true;
1019
Sadik Armagan68db21f2019-08-09 16:44:10 +01001020 std::array<DataType,3> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001021 DataType::Float32,
1022 DataType::QuantisedAsymm8,
1023 DataType::QuantisedSymm16
1024 };
1025
1026 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1027 "Reference minimum: input 0 is not a supported type.");
1028
1029 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1030 "Reference minimum: input 1 is not a supported type.");
1031
1032 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1033 "Reference minimum: output is not a supported type.");
1034
1035 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1036 "Reference minimum: input 0 and Input 1 types are mismatched");
1037
1038 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1039 "Reference minimum: input and output types are mismatched");
1040
1041 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1042 "Reference minimum: shapes are not suitable for implicit broadcast.");
1043
1044 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001045}
1046
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001047bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1048 const TensorInfo& input1,
1049 const TensorInfo& output,
1050 Optional<std::string&> reasonIfUnsupported) const
1051{
Sadik Armagan2999a022019-04-09 14:20:12 +01001052 bool supported = true;
1053
1054 std::array<DataType,3> supportedTypes = {
1055 DataType::Float32,
1056 DataType::QuantisedAsymm8,
1057 DataType::QuantisedSymm16
1058 };
1059
1060 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1061 "Reference multiplication: input 0 is not a supported type.");
1062
1063 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1064 "Reference multiplication: input 1 is not a supported type.");
1065
1066 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1067 "Reference multiplication: output is not a supported type.");
1068
1069 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1070 "Reference multiplication: input 0 and Input 1 types are mismatched");
1071
1072 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1073 "Reference multiplication: input and output types are mismatched");
1074
1075 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1076 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1077
1078 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001079}
1080
1081bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1082 const TensorInfo& output,
1083 const NormalizationDescriptor& descriptor,
1084 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001085{
Nina Drozd661dfa72018-10-02 11:14:17 +01001086 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001087
1088 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001089 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001090 {
1091 DataType::Float16,
1092 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001093 DataType::QuantisedAsymm8,
1094 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001095 };
1096
1097 bool supported = true;
1098
1099 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1100 "Reference normalization: input type not supported.");
1101
1102 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1103 "Reference normalization: output type not supported.");
1104
1105 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1106 "Reference normalization: input and output shapes have different "
1107 "num total elements.");
1108
1109 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001110}
1111
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Output layers only expose a tensor to the caller; the reference backend
    // accepts them unconditionally, so the parameters are not inspected.
    return true;
}
1117
1118bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1119 const TensorInfo& output,
1120 const PadDescriptor& descriptor,
1121 Optional<std::string&> reasonIfUnsupported) const
1122{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001123 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001124 bool supported = true;
1125
1126 // Define supported output and inputs types.
1127 std::array<DataType,3> supportedTypes =
1128 {
1129 DataType::Float32,
1130 DataType::QuantisedAsymm8,
1131 DataType::QuantisedSymm16
1132 };
1133
1134 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1135 "Reference pad: input is not a supported type.");
1136
1137 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1138 "Reference pad: output is not a supported type.");
1139
1140 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1141 "Reference pad: input and output types are mismatched.");
1142
1143 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001144}
1145
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001146bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1147 const TensorInfo& output,
1148 const PermuteDescriptor& descriptor,
1149 Optional<std::string&> reasonIfUnsupported) const
1150{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001151 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001152 bool supported = true;
1153
1154 // Define supported output and inputs types.
1155 std::array<DataType,3> supportedTypes =
1156 {
1157 DataType::Float32,
1158 DataType::QuantisedAsymm8,
1159 DataType::QuantisedSymm16
1160 };
1161
1162 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1163 "Reference permute: input is not a supported type.");
1164
1165 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1166 "Reference permute: output is not a supported type.");
1167
1168 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1169 "Reference permute: input and output types are mismatched.");
1170
1171 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001172}
1173
1174bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1175 const TensorInfo& output,
1176 const Pooling2dDescriptor& descriptor,
1177 Optional<std::string&> reasonIfUnsupported) const
1178{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001179 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001180 bool supported = true;
1181
1182 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001183 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001184 {
1185 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001186 DataType::QuantisedAsymm8,
1187 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001188 };
1189
1190 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1191 "Reference poolind2d: input is not a supported type.");
1192
1193 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1194 "Reference poolind2d: output is not a supported type.");
1195
1196 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1197 "Reference poolind2d: input and output types are mismatched.");
1198
1199 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001200}
1201
// Quantize converts a float tensor into a quantised representation, so the
// input must be float and the output one of the quantised types.
bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported input types.
    std::array<DataType,1> supportedInputTypes = {
        DataType::Float32,
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference quantize: input type not supported.");

    // Define supported output types.
    std::array<DataType,2> supportedOutputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference quantize: output type not supported.");

    // Quantization is element-wise: both tensors must hold the same element count.
    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference quantize: input and output shapes have different num total elements.");

    return supported;
}
1229
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001230bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001231 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001232 Optional<std::string&> reasonIfUnsupported) const
1233{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001234 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001235 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001236 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001237 {
1238 DataType::Float32,
1239 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001240 DataType::QuantisedAsymm8,
1241 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001242 };
1243 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1244 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001245}
1246
1247bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001248 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001249 Optional<std::string&> reasonIfUnsupported) const
1250{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001251 bool supported = true;
1252 std::array<DataType,3> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001253 {
1254 DataType::Float32,
1255 DataType::QuantisedAsymm8,
1256 DataType::QuantisedSymm16
1257 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001258
1259 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1260 "Reference ResizeBilinear: input type not supported");
1261
1262 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1263 "Reference ResizeBilinear: output type not supported");
1264
1265 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1266 "Reference ResizeBilinear: input and output types not matching");
1267
1268 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001269}
1270
Teresa Charlin970f43b2019-07-01 13:51:07 +01001271bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1272 const TensorInfo& output,
1273 const ResizeDescriptor& descriptor,
1274 Optional<std::string&> reasonIfUnsupported) const
1275{
1276 bool supported = true;
1277 std::array<DataType,3> supportedTypes =
1278 {
1279 DataType::Float32,
1280 DataType::QuantisedAsymm8,
1281 DataType::QuantisedSymm16
1282 };
1283
1284 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1285 "Reference Resize: input type not supported");
1286
1287 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1288 "Reference Resize: output type not supported");
1289
1290 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1291 "Reference Resize: input and output types not matching");
1292
1293 return supported;
1294}
1295
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001296bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1297 const TensorInfo& output,
1298 Optional<std::string&> reasonIfUnsupported) const
1299{
nikraj010421e7f2019-06-14 09:40:34 +01001300 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001301 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001302 {
1303 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001304 DataType::QuantisedAsymm8,
1305 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001306 };
1307
1308 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1309 "Reference rsqrt: input type not supported");
1310
1311 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1312 "Reference rsqrt: output type not supported");
1313
1314 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1315 "Reference rsqrt: input and output types not matching");
1316
1317 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1318 "Reference Rsqrt: input and output shapes have different number of total elements");
1319
1320 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001321}
1322
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001323bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1324 const TensorInfo& output,
1325 const SoftmaxDescriptor& descriptor,
1326 Optional<std::string&> reasonIfUnsupported) const
1327{
1328 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001329 bool supported = true;
1330 std::array<DataType,3> supportedTypes =
1331 {
1332 DataType::Float32,
1333 DataType::QuantisedAsymm8,
1334 DataType::QuantisedSymm16
1335 };
1336
1337 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1338 "Reference concatenation: output type not supported");
1339
1340 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1341 "Reference concatenation: input type not supported");
1342
1343 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1344 "Reference concatenation: input type not supported");
1345
1346 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001347}
1348
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001349bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1350 const TensorInfo& output,
1351 const SpaceToBatchNdDescriptor& descriptor,
1352 Optional<std::string&> reasonIfUnsupported) const
1353{
1354 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001355 bool supported = true;
1356 std::array<DataType,3> supportedTypes =
1357 {
1358 DataType::Float32,
1359 DataType::QuantisedAsymm8,
1360 DataType::QuantisedSymm16
1361 };
1362
1363 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1364 "Reference SpaceToBatchNd: input type not supported");
1365
1366 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1367 "Reference SpaceToBatchNd: output type not supported");
1368
1369 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1370 "Reference SpaceToBatchNd: input and output types are mismatched");
1371
1372 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001373}
1374
Keith Davisa57eccb2019-06-14 17:33:22 +01001375bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001376 const TensorInfo& output,
1377 const SpaceToDepthDescriptor& descriptor,
1378 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001379{
1380
1381 ignore_unused(descriptor);
1382 bool supported = true;
1383
James Conroyd2aa85e2019-07-01 17:12:40 +01001384 std::array<DataType,3> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001385 {
1386 DataType::Float32,
1387 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001388 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001389 };
1390
1391 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1392 "Reference SpaceToDepth: input type not supported");
1393
1394 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1395 "Reference SpaceToDepth: output type not supported");
1396
1397 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1398 "Reference SpaceToDepth: input and output types are mismatched");
1399
1400 return supported;
1401}
1402
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001403bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1404 const ViewsDescriptor& descriptor,
1405 Optional<std::string&> reasonIfUnsupported) const
1406{
1407 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001408 bool supported = true;
1409 std::array<DataType,3> supportedTypes =
1410 {
1411 DataType::Float32,
1412 DataType::QuantisedAsymm8,
1413 DataType::QuantisedSymm16
1414 };
1415
1416 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1417 "Reference splitter: input type not supported");
1418
1419 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001420}
1421
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001422bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1423 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1424 const ViewsDescriptor& descriptor,
1425 Optional<std::string&> reasonIfUnsupported) const
1426{
1427 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001428 bool supported = true;
1429 std::array<DataType,3> supportedTypes =
1430 {
1431 DataType::Float32,
1432 DataType::QuantisedAsymm8,
1433 DataType::QuantisedSymm16
1434 };
1435
1436 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1437 "Reference splitter: output type not supported");
1438 for (const TensorInfo output : outputs)
1439 {
1440 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1441 "Reference splitter: input type not supported");
1442
1443 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1444 "Reference splitter: input and output types mismatched.");
1445 }
1446
1447 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001448}
1449
Matthew Jackson81e601c2019-07-11 12:07:09 +01001450bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1451 const TensorInfo& output,
1452 const StackDescriptor& descriptor,
1453 Optional<std::string&> reasonIfUnsupported) const
1454{
1455 ignore_unused(descriptor);
1456
1457 bool supported = true;
1458 std::array<DataType,3> supportedTypes =
1459 {
1460 DataType::Float32,
1461 DataType::QuantisedAsymm8,
1462 DataType::QuantisedSymm16
1463 };
1464
1465 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1466 "Reference stack: output type not supported");
1467 for (const TensorInfo* input : inputs)
1468 {
1469 BOOST_ASSERT(input != nullptr);
1470 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1471 "Reference stack: input type not supported");
1472
1473 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1474 "Reference stack: input and output types mismatched.");
1475 }
1476
1477 return supported;
1478}
1479
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001480bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1481 const TensorInfo& output,
1482 const StridedSliceDescriptor& descriptor,
1483 Optional<std::string&> reasonIfUnsupported) const
1484{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001485 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001486 bool supported = true;
1487
1488 std::array<DataType,3> supportedTypes =
1489 {
1490 DataType::Float32,
1491 DataType::QuantisedAsymm8,
1492 DataType::QuantisedSymm16
1493 };
1494
1495 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1496 "Reference StridedSlice: input type not supported");
1497
1498 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1499 "Reference StridedSlice: output type not supported");
1500
1501 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1502 "Reference StridedSlice: input and output types are mismatched");
1503
1504 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001505}
1506
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001507bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1508 const TensorInfo& input1,
1509 const TensorInfo& output,
1510 Optional<std::string&> reasonIfUnsupported) const
1511{
Sadik Armagan2999a022019-04-09 14:20:12 +01001512 bool supported = true;
1513
1514 std::array<DataType,3> supportedTypes = {
1515 DataType::Float32,
1516 DataType::QuantisedAsymm8,
1517 DataType::QuantisedSymm16
1518 };
1519
1520 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1521 "Reference subtraction: input 0 is not a supported type.");
1522
1523 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1524 "Reference subtraction: input 1 is not a supported type.");
1525
1526 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1527 "Reference subtraction: output is not a supported type.");
1528
1529 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1530 "Reference subtraction: input 0 and Input 1 types are mismatched");
1531
1532 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1533 "Reference subtraction: input and output types are mismatched");
1534
1535 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1536 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1537
1538 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001539}
1540
Matteo Martincighab9e5252019-06-13 17:27:46 +01001541bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1542 const TensorInfo& alpha,
1543 const TensorInfo& output,
1544 Optional<std::string&> reasonIfUnsupported) const
1545{
1546 bool supported = true;
1547
1548 std::array<DataType, 3> supportedTypes
1549 {
1550 DataType::Float32,
1551 DataType::QuantisedAsymm8,
1552 DataType::QuantisedSymm16
1553 };
1554
1555 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1556 "PReLU: input is not a supported type.");
1557
1558 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1559 "PReLU: alpha is not a supported type.");
1560
1561 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1562 "PReLU: output is not a supported type.");
1563
1564 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1565 "PReLU: input, alpha and output types are mismatched");
1566
1567 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1568 "PReLU: shapes are not suitable for implicit broadcast");
1569
1570 return supported;
1571}
1572
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001573bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1574 const TensorInfo& output,
1575 const TransposeConvolution2dDescriptor& descriptor,
1576 const TensorInfo& weights,
1577 const Optional<TensorInfo>& biases,
1578 Optional<std::string&> reasonIfUnsupported) const
1579{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001580 bool supported = true;
1581
1582 std::array<DataType,3> supportedTypes =
1583 {
1584 DataType::Float32,
1585 DataType::QuantisedAsymm8,
1586 DataType::QuantisedSymm16
1587 };
1588
1589 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1590 "Reference TransposeConvolution2d: input is not a supported type.");
1591
1592 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1593 "Reference TransposeConvolution2d: output is not a supported type.");
1594
1595 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1596 "Reference TransposeConvolution2d: weights is not a supported type.");
1597
1598 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1599 "Reference TransposeConvolution2d: input and output types mismatched.");
1600
1601 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1602 "Reference TransposeConvolution2d: input and weights types mismatched.");
1603
1604 if (biases.has_value())
1605 {
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001606 std::array<DataType,3> biasesSupportedTypes =
1607 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001608 DataType::Float32,
1609 DataType::Signed32
1610 };
1611 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1612 "Reference TransposeConvolution2d: biases is not a supported type.");
1613 }
1614
1615 return supported;
1616}
1617
arovir011c7c81b2018-10-08 11:34:28 +01001618} // namespace armnn