blob: 495896817544d3922fc3c90bea2a615f25122a43 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01009#include <DataLayoutIndexed.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <InternalTypes.hpp>
11#include <LayerSupportCommon.hpp>
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +010012
telsoa014fcda012018-03-09 14:13:49 +000013#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000014#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
David Beck111b5d92018-11-12 14:59:37 +000016#include <backendsCommon/BackendRegistry.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010017#include <backendsCommon/LayerSupportRules.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010018#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010019
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
Derek Lamberti50db4e82019-03-13 14:16:15 +000022#include <vector>
23#include <algorithm>
24#include <array>
25
telsoa014fcda012018-03-09 14:13:49 +000026using namespace boost;
27
28namespace armnn
29{
30
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010031namespace
32{
33
34template<typename Float32Func, typename Uint8Func, typename ... Params>
35bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
36 DataType dataType,
37 Float32Func floatFuncPtr,
38 Uint8Func uint8FuncPtr,
39 Params&&... params)
40{
41 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
42 dataType,
43 &FalseFunc<Params...>,
44 floatFuncPtr,
45 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000046 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000047 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010048 std::forward<Params>(params)...);
49}
50
51} // anonymous namespace
52
namespace
{

// Builds the standard "wrong tensor rank" diagnostic used by the reference
// layer-support checks.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name inserted into the message.
// @param tensorName Name of the offending tensor ("input", "output", ...).
// @return           The formatted error message.
//
// The string parameters are read-only, so they are taken by const reference
// (the previous non-const references prevented passing temporaries/literals).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000068
Sadik Armagan9199e582019-09-05 17:35:31 +010069bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
70 Optional<std::string&> reasonIfUnsupported) const
71{
72 bool supported = true;
73 std::array<DataType,3> supportedTypes =
74 {
75 DataType::Float32,
76 DataType::QuantisedAsymm8,
77 DataType::QuantisedSymm16
78 };
79
80 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
81 "Reference abs: input type not supported");
82
83 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
84 "Reference abs: output type not supported");
85
86 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
87 "Reference abs: input and output types not matching");
88
89 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
90 "Reference abs: input and output shapes have different number of total elements");
91
92 return supported;
93}
94
arovir011c7c81b2018-10-08 11:34:28 +010095bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
96 const TensorInfo& output,
97 const ActivationDescriptor& descriptor,
98 Optional<std::string&> reasonIfUnsupported) const
99{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000100 bool supported = true;
101
102 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100103 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000104 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100105 DataType::Float16,
Teresa Charlin18515e22019-04-24 10:17:46 +0100106 DataType::QuantisedAsymm8,
107 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000108 };
109
110 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
111 "Reference activation: input type not supported.");
112
113 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
114 "Reference activation: output type not supported.");
115
116 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
117 "Reference activation: input and output types mismatched.");
118
119 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
120 "Reference activation: input and output shapes are of different rank.");
121
122
123 struct ActivationFunctionSupported : public Rule
124 {
125 ActivationFunctionSupported(const ActivationDescriptor& desc)
126 {
127 switch(desc.m_Function)
128 {
129 case ActivationFunction::Abs:
130 case ActivationFunction::BoundedReLu:
131 case ActivationFunction::LeakyReLu:
132 case ActivationFunction::Linear:
133 case ActivationFunction::ReLu:
134 case ActivationFunction::Sigmoid:
135 case ActivationFunction::SoftReLu:
136 case ActivationFunction::Sqrt:
137 case ActivationFunction::Square:
138 case ActivationFunction::TanH:
139 {
140 m_Res = true;
141 break;
142 }
143 default:
144 {
145 m_Res = false;
146 break;
147 }
148 }
149 }
150 };
151
152 // Function is supported
153 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
154 "Reference activation: function not supported.");
155
156 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100157}
158
159bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
160 const TensorInfo& input1,
161 const TensorInfo& output,
162 Optional<std::string&> reasonIfUnsupported) const
163{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000164 bool supported = true;
165
Matthew Jackson252df3a2019-09-11 09:19:18 +0100166 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000167 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100168 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100169 DataType::QuantisedAsymm8,
170 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000171 };
172
173 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
174 "Reference addition: input 0 is not a supported type.");
175
176 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
177 "Reference addition: input 1 is not a supported type.");
178
179 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
180 "Reference addition: output is not a supported type.");
181
182 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
183 "Reference addition: input 0 and Input 1 types are mismatched");
184
185 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
186 "Reference addition: input and output types are mismatched");
187
188 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
189 "Reference addition: shapes are not suitable for implicit broadcast.");
190
191 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100192}
193
194bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
195 const TensorInfo& output,
196 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100197 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100198 const TensorInfo& beta,
199 const TensorInfo& gamma,
200 const BatchNormalizationDescriptor& descriptor,
201 Optional<std::string&> reasonIfUnsupported) const
202{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100203 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100204
Matteo Martincighf5507132019-06-04 10:59:47 +0100205 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100206 {
207 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100208 DataType::QuantisedAsymm8,
209 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100210 };
211
212 bool supported = true;
213
214 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
215 "Reference batch normalization: input is not a supported type.");
216
217 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
218 "Reference batch normalization: output is not a supported type.");
219
220 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
221 "Reference batch normalization: input and output types are mismatched");
222
223 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
224 "Reference batch normalization: mean is not a supported type.");
225
226 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
227 "Reference batch normalization: variance is not a supported type.");
228
229 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
230 "Reference batch normalization: beta is not a supported type.");
231
232 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
233 "Reference batch normalization: gamma is not a supported type.");
234
235 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100236}
237
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000238bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
239 const TensorInfo& output,
240 const BatchToSpaceNdDescriptor& descriptor,
241 Optional<std::string&> reasonIfUnsupported) const
242{
243 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100244
245 bool supported = true;
246
247 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
248 std::string inputTensorStr = "input";
249 std::string outputTensorStr = "output";
250
251 // Define supported types.
252 std::array<DataType,3> supportedTypes =
253 {
254 DataType::Float32,
255 DataType::QuantisedAsymm8,
256 DataType::QuantisedSymm16
257 };
258
259 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
260 "Reference BatchToSpaceNd: input type not supported.");
261
262 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
263 "Reference BatchToSpaceNd: output type not supported.");
264
265 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
266 "Reference BatchToSpaceNd: input and output types mismatched.");
267
268 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
269 reasonIfUnsupported,
270 CreateIncorrectDimensionsErrorMsg(4,
271 output.GetNumDimensions(),
272 batchToSpaceNdLayerStr,
273 outputTensorStr).data());
274
275 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
276 reasonIfUnsupported,
277 CreateIncorrectDimensionsErrorMsg(4,
278 input.GetNumDimensions(),
279 batchToSpaceNdLayerStr,
280 inputTensorStr).data());
281
282 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000283}
284
Jim Flynn906f9462019-05-10 13:55:21 +0100285bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
286 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100287 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100288 Optional<std::string&> reasonIfUnsupported) const
289{
Jim Flynne242f2d2019-05-22 14:24:13 +0100290 ignore_unused(descriptor);
291
292 bool supported = true;
293 std::array<DataType,3> supportedTypes =
294 {
295 DataType::Float32,
296 DataType::QuantisedAsymm8,
297 DataType::QuantisedSymm16
298 };
299
300 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
301 "Reference concatenation: output type not supported");
302 for (const TensorInfo* input : inputs)
303 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100304 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100305 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
306 "Reference concatenation: input type not supported");
307
308 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
309 "Reference concatenation: input and output types mismatched.");
310 }
311
312 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100313}
314
arovir011c7c81b2018-10-08 11:34:28 +0100315bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
316 Optional<std::string&> reasonIfUnsupported) const
317{
Jim Flynne242f2d2019-05-22 14:24:13 +0100318 std::array<DataType,4> supportedTypes =
319 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100320 DataType::Float32,
321 DataType::Signed32,
322 DataType::QuantisedAsymm8,
323 DataType::QuantisedSymm16
324 };
325
326 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
327 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100328}
329
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only when the input is Float16 and the output is Float32.
    // Each IsSupportedForDataTypeGeneric call dispatches on the data type to a
    // positional list of predicate functions; the False*Func entries reject all
    // other types (positions presumably map to F16/F32/U8/I32/Bool in that
    // order — confirm against LayerSupportCommon.hpp).
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
349
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror of IsConvertFp16ToFp32Supported: supported only when the input is
    // Float32 and the output is Float16. The False*Func placeholders reject
    // every other data type via the positional dispatch in
    // IsSupportedForDataTypeGeneric.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
369
370bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
371 const TensorInfo& output,
372 const Convolution2dDescriptor& descriptor,
373 const TensorInfo& weights,
374 const Optional<TensorInfo>& biases,
375 Optional<std::string&> reasonIfUnsupported) const
376{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100377 bool supported = true;
378
379 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100380 std::array<DataType,4> supportedTypes = {
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100381 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100382 DataType::Float16,
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100383 DataType::QuantisedAsymm8,
384 DataType::QuantisedSymm16
385 };
386
387 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100388 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100389
390 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100391 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100392
393 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100394 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100395
396 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100397 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100398
399 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100400 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100401
402 if (biases.has_value())
403 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100404 std::array<DataType,3> biasesSupportedTypes = {
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100405 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100406 DataType::Float16,
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100407 DataType::Signed32
408 };
409 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100410 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100411 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100412 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100413
414 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100415}
416
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000417bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
418 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000419 Optional<std::string&> reasonIfUnsupported) const
420{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100421 bool supported = true;
422
423 std::array<DataType,3> supportedTypes =
424 {
425 DataType::Float32,
426 DataType::QuantisedAsymm8,
427 DataType::QuantisedSymm16
428 };
429
430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
431 "Reference debug: input type not supported");
432
433 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
434 "Reference debug: output type not supported");
435
436 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
437 "Reference debug: input and output types are mismatched");
438
439 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000440}
441
arovir011c7c81b2018-10-08 11:34:28 +0100442bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
443 const TensorInfo& output,
444 const DepthwiseConvolution2dDescriptor& descriptor,
445 const TensorInfo& weights,
446 const Optional<TensorInfo>& biases,
447 Optional<std::string&> reasonIfUnsupported) const
448{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100449 bool supported = true;
450
451 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100452 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100453 {
454 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100455 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100456 DataType::QuantisedAsymm8,
457 DataType::QuantisedSymm16
458 };
459
460 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
461 "Reference DepthwiseConvolution2d: input is not a supported type.");
462
463 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
464 "Reference DepthwiseConvolution2d: output is not a supported type.");
465
466 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
467 "Reference DepthwiseConvolution2d: weights is not a supported type.");
468
469 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
470 "Reference DepthwiseConvolution2d: input and output types mismatched.");
471
472 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
473 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
474
475 if (biases.has_value())
476 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100477 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100478 {
479 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100480 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100481 DataType::Signed32
482 };
483 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
484 "Reference DepthwiseConvolution2d: biases is not a supported type.");
485 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100486 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100487
488 return supported;
489
arovir011c7c81b2018-10-08 11:34:28 +0100490}
491
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000492bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
493 const TensorInfo& output,
494 Optional<std::string&> reasonIfUnsupported) const
495{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100496 bool supported = true;
497
498 std::array<DataType,2> supportedInputTypes = {
499 DataType::QuantisedAsymm8,
500 DataType::QuantisedSymm16
501 };
502
503 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
504 "Reference dequantize: input type not supported.");
505
Mike Kelly4992c342019-08-14 11:33:11 +0100506 std::array<DataType,1> supportedOutputTypes = {
507 DataType::Float32
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100508 };
509
510 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
511 "Reference dequantize: output type not supported.");
512
513 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
514 "Reference dequantize: input and output shapes have different num total elements.");
515
516 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000517}
518
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000519bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
520 const armnn::TensorInfo& input1,
521 const armnn::DetectionPostProcessDescriptor& descriptor,
522 armnn::Optional<std::string&> reasonIfUnsupported) const
523{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100524 bool supported = true;
525
Mike Kelly4992c342019-08-14 11:33:11 +0100526 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100527 {
528 DataType::Float32,
529 DataType::QuantisedAsymm8,
530 DataType::QuantisedSymm16
531 };
532
533 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
534 "Reference DetectionPostProcess: input 0 is not a supported type.");
535
536 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
537 "Reference DetectionPostProcess: input 1 is not a supported type.");
538
539 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000540}
541
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // Dilated depthwise shares the plain depthwise support criteria, so defer
    // to that check (the descriptor carries the dilation parameters).
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
551
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100552bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100553 const TensorInfo& input1,
554 const TensorInfo& output,
555 Optional<std::string&> reasonIfUnsupported) const
556{
Sadik Armagan2999a022019-04-09 14:20:12 +0100557 bool supported = true;
558
559 std::array<DataType,3> supportedTypes = {
560 DataType::Float32,
561 DataType::QuantisedAsymm8,
562 DataType::QuantisedSymm16
563 };
564
565 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
566 "Reference division: input 0 is not a supported type.");
567
568 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
569 "Reference division: input 1 is not a supported type.");
570
571 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
572 "Reference division: output is not a supported type.");
573
574 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
575 "Reference division: input 0 and Input 1 types are mismatched");
576
577 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
578 "Reference division: input and output types are mismatched");
579
580 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
581 "Reference division: shapes are not suitable for implicit broadcast.");
582
583 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100584}
585
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000586bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
587 const TensorInfo& input1,
588 const TensorInfo& output,
589 Optional<std::string&> reasonIfUnsupported) const
590{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100591 bool supported = true;
592
593 std::array<DataType,3> supportedTypes =
594 {
595 DataType::Float32,
596 DataType::QuantisedAsymm8,
597 DataType::QuantisedSymm16
598 };
599
600 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
601 "Reference equal: input 0 is not a supported type.");
602
603 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
604 "Reference equal: input 1 is not a supported type.");
605
606 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
607 "Reference equal: input 0 and Input 1 types are mismatched");
608
609 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
610 "Reference equal: shapes are not suitable for implicit broadcast.");
611
612 return supported;
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000613}
614
arovir011c7c81b2018-10-08 11:34:28 +0100615bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
616 const FakeQuantizationDescriptor& descriptor,
617 Optional<std::string&> reasonIfUnsupported) const
618{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100619 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100620 bool supported = true;
621
622 std::array<DataType,1> supportedTypes =
623 {
624 DataType::Float32
625 };
626
627 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
628 "Reference fake quantization: input type not supported.");
629
630 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100631}
632
633bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
634 const TensorInfo& output,
635 Optional<std::string&> reasonIfUnsupported) const
636{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100637 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100638 bool supported = true;
639
James Conroyb40d7102019-06-04 12:32:09 +0100640 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100641 {
James Conroyb40d7102019-06-04 12:32:09 +0100642 DataType::Float32,
643 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100644 };
645
646 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
647 "Reference Floor: input type not supported.");
648
649 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
650 "Reference Floor: output type not supported.");
651
652 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100653}
654
655bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
656 const TensorInfo& output,
657 const TensorInfo& weights,
658 const TensorInfo& biases,
659 const FullyConnectedDescriptor& descriptor,
660 Optional<std::string&> reasonIfUnsupported) const
661{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100662 bool supported = true;
663
664 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100665 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100666 {
667 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100668 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100669 DataType::QuantisedAsymm8,
670 DataType::QuantisedSymm16
671 };
672
673 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
674 "Reference Fully Connected: input type not supported.");
675
676 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
677 "Reference Fully Connected: output type not supported.");
678
679 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
680 "Reference Fully Connected: input and output types mismatched.");
681
682 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
683 "Reference Fully Connected: weights type not supported.");
684
685 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
686 "Reference Fully Connected: input and weight types mismatched.");
687
688 if (descriptor.m_BiasEnabled)
689 {
690 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100691 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100692 supportedBiasTypes =
693 {
694 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100695 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100696 DataType::Signed32
697 };
698
699 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
700 "Reference Fully Connected: bias type not supported.");
701
702 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
703 "Reference Fully Connected: bias and weight types mismatch.");
704
705 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
706 "Reference Fully Connected: bias type inferred from weights is incompatible.");
707
708 }
709
710 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100711}
712
narpra014951d842019-01-18 16:53:53 +0000713bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
714 const armnn::TensorInfo& input1,
715 const armnn::TensorInfo& output,
716 armnn::Optional<std::string&> reasonIfUnsupported) const
717{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100718 bool supported = true;
719 std::array<DataType,3> supportedTypes =
720 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100721 DataType::Float32,
722 DataType::QuantisedAsymm8,
723 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100724 };
725
726 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
727 "Reference Gather: input type not supported");
728
729 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
730 "Reference Gather: output type not supported");
731
732 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
733 "Reference Gather: indices (input1) type not supported");
734
735 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
736 "Reference Gather: input and output types not matching");
737
738 return supported;
narpra014951d842019-01-18 16:53:53 +0000739}
740
FrancisMurtagh878f0232018-12-19 10:56:15 +0000741bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
742 const TensorInfo& input1,
743 const TensorInfo& output,
744 Optional<std::string&> reasonIfUnsupported) const
745{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100746 bool supported = true;
747
748 std::array<DataType,3> supportedTypes =
749 {
750 DataType::Float32,
751 DataType::QuantisedAsymm8,
752 DataType::QuantisedSymm16
753 };
754
755 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
756 "Reference greater: input 0 is not a supported type.");
757
758 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
759 "Reference greater: input 1 is not a supported type.");
760
761 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
762 "Reference greater: input 0 and Input 1 types are mismatched");
763
764 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
765 "Reference greater: shapes are not suitable for implicit broadcast.");
766
767 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000768}
769
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    // Input layers carry data into the graph without computation, so the
    // reference backend accepts them unconditionally (no reason is ever set).
    return true;
}
775
776bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
777 const TensorInfo& output,
778 const L2NormalizationDescriptor& descriptor,
779 Optional<std::string&> reasonIfUnsupported) const
780{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100781 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100782 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100783 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100784 {
785 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100786 DataType::Float16,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100787 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100788 DataType::QuantisedSymm16
789 };
790
791 bool supported = true;
792
793 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
794 "Reference L2normalization: input type not supported.");
795
796 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
797 "Reference L2normalization: output type not supported.");
798
799 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
800 "Reference L2normalization: input and output types mismatched.");
801
802 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
803 "Reference L2normalization: input and output shapes have different "
804 "num total elements.");
805
806 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100807}
808
// Checks whether the reference backend can execute an LSTM layer with the
// given tensors and configuration. Every tensor (states, scratch buffer,
// outputs) and every weight/bias in paramsInfo must have the same data type
// as 'input', which itself must be Float32 or QuantisedSymm16. Optional
// parameter groups are only validated when the corresponding descriptor
// flag (CIFG / peephole / projection / layer-norm) enables them.
// All rules are evaluated (no short-circuit) so reasonIfUnsupported can
// accumulate every failure, not just the first.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters: mandatory weights and biases, always present
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // Input-gate parameters only exist when CIFG (coupled input-forget gate) is off.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        // Cell-to-input peephole weights exist only with peephole AND no CIFG.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        // The projection bias is optional even when projection is enabled.
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        // Input-gate layer-norm weights exist only when CIFG is off.
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
918
saoste012df12b32018-11-28 16:57:20 +0000919bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
920 const TensorInfo& input1,
921 const TensorInfo& output,
922 Optional<std::string&> reasonIfUnsupported) const
923{
Sadik Armagan2999a022019-04-09 14:20:12 +0100924 bool supported = true;
925
Sadik Armagan68db21f2019-08-09 16:44:10 +0100926 std::array<DataType,3> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100927 DataType::Float32,
928 DataType::QuantisedAsymm8,
929 DataType::QuantisedSymm16
930 };
931
932 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
933 "Reference maximum: input 0 is not a supported type.");
934
935 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
936 "Reference maximum: input 1 is not a supported type.");
937
938 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
939 "Reference maximum: output is not a supported type.");
940
941 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
942 "Reference maximum: input 0 and Input 1 types are mismatched");
943
944 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
945 "Reference maximum: input and output types are mismatched");
946
947 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
948 "Reference maximum: shapes are not suitable for implicit broadcast.");
949
950 return supported;
saoste012df12b32018-11-28 16:57:20 +0000951}
952
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100953bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
954 const TensorInfo& output,
955 const MeanDescriptor& descriptor,
956 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100957{
James Conroy4d1ff582019-06-10 17:06:39 +0100958 bool supported = true;
959 std::string meanLayerStr = "Mean";
960 std::string outputTensorStr = "output";
961
Matthew Jackson252df3a2019-09-11 09:19:18 +0100962 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +0100963 {
964 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100965 DataType::Float16,
James Conroyb80775f2019-06-11 11:25:30 +0100966 DataType::QuantisedAsymm8,
967 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +0100968 };
969
970 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
971 "Reference Mean: input type not supported.");
972
973 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
974 "Reference Mean: input and output types are mismatched");
975
976 if (descriptor.m_KeepDims)
977 {
978 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
979 reasonIfUnsupported,
980 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
981 output.GetNumDimensions(),
982 meanLayerStr, outputTensorStr).data());
983 }
984 else if (descriptor.m_Axis.empty())
985 {
986 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
987 reasonIfUnsupported,
988 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
989 meanLayerStr, outputTensorStr).data());
990 }
991 else
992 {
993 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
994
995 if (outputDim > 0)
996 {
997 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
998 reasonIfUnsupported,
999 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1000 meanLayerStr, outputTensorStr).data());
1001 }
1002 else
1003 {
1004 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1005 reasonIfUnsupported,
1006 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1007 meanLayerStr, outputTensorStr).data());
1008 }
1009 }
1010
1011 return supported;
narpra0132b90462018-09-13 11:07:48 +01001012}
1013
// Merger is checked with exactly the same rules as Concat; this simply
// forwards all arguments to IsConcatSupported.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
1021
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001022bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1023 const TensorInfo &output,
1024 Optional<std::string &> reasonIfUnsupported) const
1025{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001026 bool supported = true;
1027
1028 std::array<DataType,5> supportedTypes =
1029 {
1030 DataType::Float32,
1031 DataType::Float16,
1032 DataType::QuantisedAsymm8,
1033 DataType::QuantisedSymm16,
1034 DataType::Boolean
1035 };
1036
1037 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1038 "Reference MemCopy: input type not supported");
1039
1040 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1041 "Reference MemCopy: output type not supported");
1042
1043 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1044 "Reference MemCopy: input and output types are mismatched");
1045
1046 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001047}
1048
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001049bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1050 const TensorInfo& input1,
1051 const TensorInfo& output,
1052 Optional<std::string&> reasonIfUnsupported) const
1053{
Sadik Armagan2999a022019-04-09 14:20:12 +01001054 bool supported = true;
1055
Sadik Armagan68db21f2019-08-09 16:44:10 +01001056 std::array<DataType,3> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001057 DataType::Float32,
1058 DataType::QuantisedAsymm8,
1059 DataType::QuantisedSymm16
1060 };
1061
1062 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1063 "Reference minimum: input 0 is not a supported type.");
1064
1065 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1066 "Reference minimum: input 1 is not a supported type.");
1067
1068 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1069 "Reference minimum: output is not a supported type.");
1070
1071 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1072 "Reference minimum: input 0 and Input 1 types are mismatched");
1073
1074 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1075 "Reference minimum: input and output types are mismatched");
1076
1077 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1078 "Reference minimum: shapes are not suitable for implicit broadcast.");
1079
1080 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001081}
1082
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001083bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1084 const TensorInfo& input1,
1085 const TensorInfo& output,
1086 Optional<std::string&> reasonIfUnsupported) const
1087{
Sadik Armagan2999a022019-04-09 14:20:12 +01001088 bool supported = true;
1089
Matthew Jackson252df3a2019-09-11 09:19:18 +01001090 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001091 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001092 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001093 DataType::QuantisedAsymm8,
1094 DataType::QuantisedSymm16
1095 };
1096
1097 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1098 "Reference multiplication: input 0 is not a supported type.");
1099
1100 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1101 "Reference multiplication: input 1 is not a supported type.");
1102
1103 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1104 "Reference multiplication: output is not a supported type.");
1105
1106 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1107 "Reference multiplication: input 0 and Input 1 types are mismatched");
1108
1109 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1110 "Reference multiplication: input and output types are mismatched");
1111
1112 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1113 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1114
1115 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001116}
1117
1118bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1119 const TensorInfo& output,
1120 const NormalizationDescriptor& descriptor,
1121 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001122{
Nina Drozd661dfa72018-10-02 11:14:17 +01001123 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001124
1125 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001126 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001127 {
1128 DataType::Float16,
1129 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001130 DataType::QuantisedAsymm8,
1131 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001132 };
1133
1134 bool supported = true;
1135
1136 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1137 "Reference normalization: input type not supported.");
1138
1139 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1140 "Reference normalization: output type not supported.");
1141
1142 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1143 "Reference normalization: input and output shapes have different "
1144 "num total elements.");
1145
1146 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001147}
1148
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Output layers carry data out of the graph without computation, so the
    // reference backend accepts them unconditionally (no reason is ever set).
    return true;
}
1154
1155bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1156 const TensorInfo& output,
1157 const PadDescriptor& descriptor,
1158 Optional<std::string&> reasonIfUnsupported) const
1159{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001160 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001161 bool supported = true;
1162
1163 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001164 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001165 {
1166 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001167 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001168 DataType::QuantisedAsymm8,
1169 DataType::QuantisedSymm16
1170 };
1171
1172 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1173 "Reference pad: input is not a supported type.");
1174
1175 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1176 "Reference pad: output is not a supported type.");
1177
1178 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1179 "Reference pad: input and output types are mismatched.");
1180
1181 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001182}
1183
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001184bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1185 const TensorInfo& output,
1186 const PermuteDescriptor& descriptor,
1187 Optional<std::string&> reasonIfUnsupported) const
1188{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001189 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001190 bool supported = true;
1191
1192 // Define supported output and inputs types.
1193 std::array<DataType,3> supportedTypes =
1194 {
1195 DataType::Float32,
1196 DataType::QuantisedAsymm8,
1197 DataType::QuantisedSymm16
1198 };
1199
1200 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1201 "Reference permute: input is not a supported type.");
1202
1203 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1204 "Reference permute: output is not a supported type.");
1205
1206 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1207 "Reference permute: input and output types are mismatched.");
1208
1209 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001210}
1211
1212bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1213 const TensorInfo& output,
1214 const Pooling2dDescriptor& descriptor,
1215 Optional<std::string&> reasonIfUnsupported) const
1216{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001217 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001218 bool supported = true;
1219
1220 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001221 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001222 {
1223 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001224 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001225 DataType::QuantisedAsymm8,
1226 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001227 };
1228
1229 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1230 "Reference poolind2d: input is not a supported type.");
1231
1232 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1233 "Reference poolind2d: output is not a supported type.");
1234
1235 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1236 "Reference poolind2d: input and output types are mismatched.");
1237
1238 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001239}
1240
bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported input types: quantization starts from float data.
    std::array<DataType,1> supportedInputTypes = {
        DataType::Float32,
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference quantize: input type not supported.");

    // Define supported output types: the quantized representations.
    std::array<DataType,2> supportedOutputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference quantize: output type not supported.");

    // Input and output must describe the same number of elements.
    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference quantize: input and output shapes have different num total elements.");

    return supported;
}
1268
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001269bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001270 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001271 Optional<std::string&> reasonIfUnsupported) const
1272{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001273 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001274 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001275 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001276 {
1277 DataType::Float32,
1278 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001279 DataType::QuantisedAsymm8,
1280 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001281 };
1282 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1283 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001284}
1285
1286bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001287 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001288 Optional<std::string&> reasonIfUnsupported) const
1289{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001290 bool supported = true;
1291 std::array<DataType,3> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001292 {
1293 DataType::Float32,
1294 DataType::QuantisedAsymm8,
1295 DataType::QuantisedSymm16
1296 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001297
1298 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1299 "Reference ResizeBilinear: input type not supported");
1300
1301 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1302 "Reference ResizeBilinear: output type not supported");
1303
1304 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1305 "Reference ResizeBilinear: input and output types not matching");
1306
1307 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001308}
1309
Teresa Charlin970f43b2019-07-01 13:51:07 +01001310bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1311 const TensorInfo& output,
1312 const ResizeDescriptor& descriptor,
1313 Optional<std::string&> reasonIfUnsupported) const
1314{
1315 bool supported = true;
1316 std::array<DataType,3> supportedTypes =
1317 {
1318 DataType::Float32,
1319 DataType::QuantisedAsymm8,
1320 DataType::QuantisedSymm16
1321 };
1322
1323 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1324 "Reference Resize: input type not supported");
1325
1326 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1327 "Reference Resize: output type not supported");
1328
1329 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1330 "Reference Resize: input and output types not matching");
1331
1332 return supported;
1333}
1334
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001335bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1336 const TensorInfo& output,
1337 Optional<std::string&> reasonIfUnsupported) const
1338{
nikraj010421e7f2019-06-14 09:40:34 +01001339 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001340 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001341 {
1342 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001343 DataType::QuantisedAsymm8,
1344 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001345 };
1346
1347 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1348 "Reference rsqrt: input type not supported");
1349
1350 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1351 "Reference rsqrt: output type not supported");
1352
1353 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1354 "Reference rsqrt: input and output types not matching");
1355
1356 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1357 "Reference Rsqrt: input and output shapes have different number of total elements");
1358
1359 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001360}
1361
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001362bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1363 const TensorInfo& output,
1364 const SoftmaxDescriptor& descriptor,
1365 Optional<std::string&> reasonIfUnsupported) const
1366{
1367 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001368 bool supported = true;
1369 std::array<DataType,3> supportedTypes =
1370 {
1371 DataType::Float32,
1372 DataType::QuantisedAsymm8,
1373 DataType::QuantisedSymm16
1374 };
1375
1376 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1377 "Reference concatenation: output type not supported");
1378
1379 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1380 "Reference concatenation: input type not supported");
1381
1382 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1383 "Reference concatenation: input type not supported");
1384
1385 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001386}
1387
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001388bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1389 const TensorInfo& output,
1390 const SpaceToBatchNdDescriptor& descriptor,
1391 Optional<std::string&> reasonIfUnsupported) const
1392{
1393 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001394 bool supported = true;
1395 std::array<DataType,3> supportedTypes =
1396 {
1397 DataType::Float32,
1398 DataType::QuantisedAsymm8,
1399 DataType::QuantisedSymm16
1400 };
1401
1402 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1403 "Reference SpaceToBatchNd: input type not supported");
1404
1405 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1406 "Reference SpaceToBatchNd: output type not supported");
1407
1408 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1409 "Reference SpaceToBatchNd: input and output types are mismatched");
1410
1411 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001412}
1413
Keith Davisa57eccb2019-06-14 17:33:22 +01001414bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001415 const TensorInfo& output,
1416 const SpaceToDepthDescriptor& descriptor,
1417 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001418{
1419
1420 ignore_unused(descriptor);
1421 bool supported = true;
1422
James Conroyd2aa85e2019-07-01 17:12:40 +01001423 std::array<DataType,3> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001424 {
1425 DataType::Float32,
1426 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001427 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001428 };
1429
1430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431 "Reference SpaceToDepth: input type not supported");
1432
1433 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1434 "Reference SpaceToDepth: output type not supported");
1435
1436 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1437 "Reference SpaceToDepth: input and output types are mismatched");
1438
1439 return supported;
1440}
1441
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001442bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1443 const ViewsDescriptor& descriptor,
1444 Optional<std::string&> reasonIfUnsupported) const
1445{
1446 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001447 bool supported = true;
1448 std::array<DataType,3> supportedTypes =
1449 {
1450 DataType::Float32,
1451 DataType::QuantisedAsymm8,
1452 DataType::QuantisedSymm16
1453 };
1454
1455 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1456 "Reference splitter: input type not supported");
1457
1458 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001459}
1460
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001461bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1462 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1463 const ViewsDescriptor& descriptor,
1464 Optional<std::string&> reasonIfUnsupported) const
1465{
1466 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001467 bool supported = true;
1468 std::array<DataType,3> supportedTypes =
1469 {
1470 DataType::Float32,
1471 DataType::QuantisedAsymm8,
1472 DataType::QuantisedSymm16
1473 };
1474
1475 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1476 "Reference splitter: output type not supported");
1477 for (const TensorInfo output : outputs)
1478 {
1479 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1480 "Reference splitter: input type not supported");
1481
1482 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1483 "Reference splitter: input and output types mismatched.");
1484 }
1485
1486 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001487}
1488
Matthew Jackson81e601c2019-07-11 12:07:09 +01001489bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1490 const TensorInfo& output,
1491 const StackDescriptor& descriptor,
1492 Optional<std::string&> reasonIfUnsupported) const
1493{
1494 ignore_unused(descriptor);
1495
1496 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001497 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001498 {
1499 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001500 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001501 DataType::QuantisedAsymm8,
1502 DataType::QuantisedSymm16
1503 };
1504
1505 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1506 "Reference stack: output type not supported");
1507 for (const TensorInfo* input : inputs)
1508 {
1509 BOOST_ASSERT(input != nullptr);
1510 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1511 "Reference stack: input type not supported");
1512
1513 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1514 "Reference stack: input and output types mismatched.");
1515 }
1516
1517 return supported;
1518}
1519
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001520bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1521 const TensorInfo& output,
1522 const StridedSliceDescriptor& descriptor,
1523 Optional<std::string&> reasonIfUnsupported) const
1524{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001525 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001526 bool supported = true;
1527
1528 std::array<DataType,3> supportedTypes =
1529 {
1530 DataType::Float32,
1531 DataType::QuantisedAsymm8,
1532 DataType::QuantisedSymm16
1533 };
1534
1535 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1536 "Reference StridedSlice: input type not supported");
1537
1538 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1539 "Reference StridedSlice: output type not supported");
1540
1541 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1542 "Reference StridedSlice: input and output types are mismatched");
1543
1544 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001545}
1546
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001547bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1548 const TensorInfo& input1,
1549 const TensorInfo& output,
1550 Optional<std::string&> reasonIfUnsupported) const
1551{
Sadik Armagan2999a022019-04-09 14:20:12 +01001552 bool supported = true;
1553
1554 std::array<DataType,3> supportedTypes = {
1555 DataType::Float32,
1556 DataType::QuantisedAsymm8,
1557 DataType::QuantisedSymm16
1558 };
1559
1560 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1561 "Reference subtraction: input 0 is not a supported type.");
1562
1563 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1564 "Reference subtraction: input 1 is not a supported type.");
1565
1566 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1567 "Reference subtraction: output is not a supported type.");
1568
1569 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1570 "Reference subtraction: input 0 and Input 1 types are mismatched");
1571
1572 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1573 "Reference subtraction: input and output types are mismatched");
1574
1575 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1576 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1577
1578 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001579}
1580
Matteo Martincighab9e5252019-06-13 17:27:46 +01001581bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1582 const TensorInfo& alpha,
1583 const TensorInfo& output,
1584 Optional<std::string&> reasonIfUnsupported) const
1585{
1586 bool supported = true;
1587
1588 std::array<DataType, 3> supportedTypes
1589 {
1590 DataType::Float32,
1591 DataType::QuantisedAsymm8,
1592 DataType::QuantisedSymm16
1593 };
1594
1595 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1596 "PReLU: input is not a supported type.");
1597
1598 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1599 "PReLU: alpha is not a supported type.");
1600
1601 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1602 "PReLU: output is not a supported type.");
1603
1604 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1605 "PReLU: input, alpha and output types are mismatched");
1606
1607 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1608 "PReLU: shapes are not suitable for implicit broadcast");
1609
1610 return supported;
1611}
1612
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001613bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1614 const TensorInfo& output,
1615 const TransposeConvolution2dDescriptor& descriptor,
1616 const TensorInfo& weights,
1617 const Optional<TensorInfo>& biases,
1618 Optional<std::string&> reasonIfUnsupported) const
1619{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001620 bool supported = true;
1621
Matthew Jackson252df3a2019-09-11 09:19:18 +01001622 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001623 {
1624 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001625 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001626 DataType::QuantisedAsymm8,
1627 DataType::QuantisedSymm16
1628 };
1629
1630 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1631 "Reference TransposeConvolution2d: input is not a supported type.");
1632
1633 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1634 "Reference TransposeConvolution2d: output is not a supported type.");
1635
1636 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1637 "Reference TransposeConvolution2d: weights is not a supported type.");
1638
1639 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1640 "Reference TransposeConvolution2d: input and output types mismatched.");
1641
1642 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1643 "Reference TransposeConvolution2d: input and weights types mismatched.");
1644
1645 if (biases.has_value())
1646 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001647 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001648 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001649 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001650 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001651 DataType::Signed32
1652 };
1653 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1654 "Reference TransposeConvolution2d: biases is not a supported type.");
1655 }
1656
1657 return supported;
1658}
1659
arovir011c7c81b2018-10-08 11:34:28 +01001660} // namespace armnn