//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <DataLayoutIndexed.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/LayerSupportRules.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

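// Legacy dispatch helper: forwards to IsSupportedForDataTypeGeneric, supplying
// FalseFunc<> for the Float16, Signed32 and Boolean slots so those data types
// are reported as unsupported by the callers that still use this path.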
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

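// Builds the reason string used by the TensorNumDimensionsAreCorrect checks below,
// e.g. "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions instead,
// for the 'output' tensor."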
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

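// Each IsXxxSupported method below follows the same pattern: start with
// supported = true, then AND in one CheckSupportRule(...) per constraint
// (supported tensor data types, matching input/output types, compatible shapes).
// A failed rule appends its message to reasonIfUnsupported, so callers receive
// every reason the layer cannot run on the reference backend.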
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

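// Illustrative sketch only (inputInfo, outputInfo and activationDesc are placeholder
// objects, not defined in this file): a caller can collect the reasons a layer is
// rejected by passing an Optional<std::string&>:
//
//     RefLayerSupport layerSupport;
//     std::string reason;
//     bool ok = layerSupport.IsActivationSupported(inputInfo, outputInfo, activationDesc,
//                                                  Optional<std::string&>(reason));
//     // If 'ok' is false, 'reason' lists every failed rule.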
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference convolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,2> biasesSupportedTypes = {
            DataType::Float32,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,2> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,1> supportedOutputTypes = {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference equal: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference equal: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference equal: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference equal: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,2> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Define supported types for bias.
        std::array<DataType, 2> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference greater: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference greater: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference greater: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference greater: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}

bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}

bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,5> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16,
        DataType::Boolean
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference MemCopy: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference MemCopy: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference MemCopy: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference minimum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference minimum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference minimum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference minimum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference minimum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference minimum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                                const TensorInfo& input1,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference multiplication: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference multiplication: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference multiplication: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference multiplication: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference multiplication: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference multiplication: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const NormalizationDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float16,
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference normalization: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

1144bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1145 const TensorInfo& output,
1146 const PadDescriptor& descriptor,
1147 Optional<std::string&> reasonIfUnsupported) const
1148{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001149 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001150 bool supported = true;
1151
1152 // Define supported output and inputs types.
1153 std::array<DataType,3> supportedTypes =
1154 {
1155 DataType::Float32,
1156 DataType::QuantisedAsymm8,
1157 DataType::QuantisedSymm16
1158 };
1159
1160 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1161 "Reference pad: input is not a supported type.");
1162
1163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1164 "Reference pad: output is not a supported type.");
1165
1166 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1167 "Reference pad: input and output types are mismatched.");
1168
1169 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001170}
1171
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001172bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1173 const TensorInfo& output,
1174 const PermuteDescriptor& descriptor,
1175 Optional<std::string&> reasonIfUnsupported) const
1176{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001177 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001178 bool supported = true;
1179
1180 // Define supported output and inputs types.
1181 std::array<DataType,3> supportedTypes =
1182 {
1183 DataType::Float32,
1184 DataType::QuantisedAsymm8,
1185 DataType::QuantisedSymm16
1186 };
1187
1188 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1189 "Reference permute: input is not a supported type.");
1190
1191 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1192 "Reference permute: output is not a supported type.");
1193
1194 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1195 "Reference permute: input and output types are mismatched.");
1196
1197 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001198}
1199
bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const Pooling2dDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    // Define the supported input and output types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference pooling2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference pooling2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference pooling2d: input and output types are mismatched.");

    return supported;
}

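// Quantize converts a Float32 tensor into a quantized one, so it is the one check below where the
// input and output types are expected to differ. A minimal usage sketch follows (illustrative only;
// the TensorInfo constructor arguments and the 'refLayerSupport' instance are assumptions, not code
// taken from this file):
//     TensorInfo input({1, 4}, DataType::Float32);
//     TensorInfo output({1, 4}, DataType::QuantisedAsymm8, 1.0f / 255.0f, 0);
//     std::string reason;
//     RefLayerSupport refLayerSupport;
//     bool ok = refLayerSupport.IsQuantizeSupported(input, output, Optional<std::string&>(reason));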
bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define the supported input types.
    std::array<DataType,1> supportedInputTypes = {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference quantize: input type not supported.");

    // Define the supported output types.
    std::array<DataType,2> supportedOutputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference quantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference quantize: input and output shapes have different num total elements.");

    return supported;
}

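// Reshape only reinterprets the tensor's shape; this overload therefore takes no output TensorInfo
// and validates the input type alone.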
bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                         const ReshapeDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define the supported input types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };
    return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                            "Reference reshape: input type not supported.");
}

bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ResizeBilinear: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference ResizeBilinear: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference ResizeBilinear: input and output types not matching");

    return supported;
}

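// IsResizeSupported is the generic counterpart of IsResizeBilinearSupported above: the resize
// method is carried by the ResizeDescriptor rather than being fixed to bilinear, while the set of
// supported data types is the same.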
bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const ResizeDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Resize: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Resize: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Resize: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference rsqrt: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference rsqrt: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference rsqrt: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference rsqrt: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const SoftmaxDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference softmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference softmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference softmax: input and output types mismatched");

    return supported;
}

bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const SpaceToBatchNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference SpaceToBatchNd: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference SpaceToBatchNd: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference SpaceToBatchNd: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const SpaceToDepthDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference SpaceToDepth: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference SpaceToDepth: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference SpaceToDepth: input and output types are mismatched");

    return supported;
}

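// Two splitter overloads follow: this first one validates only the input tensor, while the second
// also validates each output view produced by the split against the input.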
bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference splitter: input type not supported");

    return supported;
}

bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference splitter: input type not supported");
    for (const TensorInfo& output : outputs)
    {
        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                      "Reference splitter: output type not supported");

        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                      "Reference splitter: input and output types mismatched.");
    }

    return supported;
}

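// Stack validates the output type first and then every input against it; note that Float16 is
// accepted here in addition to the Float32 and quantized types used by most of the checks above.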
bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
                                       const TensorInfo& output,
                                       const StackDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference stack: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference stack: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference stack: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const StridedSliceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference StridedSlice: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference StridedSlice: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference StridedSlice: input and output types are mismatched");

    return supported;
}

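// Subtraction allows implicit broadcasting, so on top of the type checks the final rule verifies
// that the two input shapes are broadcast-compatible with the output. A sketch of the intended
// semantics (shapes are illustrative only, assuming the usual per-dimension "equal or 1" rule):
//     input0 [1, 2, 2, 3] - input1 [1, 1, 1, 3] -> output [1, 2, 2, 3]   // compatible
//     input0 [1, 2, 2, 3] - input1 [1, 2, 2, 2] -> no valid broadcast    // rejected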
bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                             const TensorInfo& input1,
                                             const TensorInfo& output,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference subtraction: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference subtraction: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference subtraction: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference subtraction: input 0 and input 1 types are mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference subtraction: input and output types are mismatched.");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference subtraction: shapes are not suitable for implicit broadcast.");

    return supported;
}

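// PReLU takes a learnable alpha tensor as a second input; the rules below require alpha to share a
// supported type with the input and to be broadcastable to the input/output shape.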
bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
                                       const TensorInfo& alpha,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 3> supportedTypes
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "PReLU: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
                                  "PReLU: alpha is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "PReLU: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
                                  "PReLU: input, alpha and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
                                  "PReLU: shapes are not suitable for implicit broadcast");

    return supported;
}

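// TransposeConvolution2d additionally validates the weights and, when present, the optional bias.
// The bias rule only constrains the data type to Float32 or Signed32; pairing of bias type with the
// input type (e.g. Signed32 bias for quantized inputs) is not enforced here.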
bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
                                                        const TensorInfo& output,
                                                        const TransposeConvolution2dDescriptor& descriptor,
                                                        const TensorInfo& weights,
                                                        const Optional<TensorInfo>& biases,
                                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,2> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: biases is not a supported type.");
    }

    return supported;
}

} // namespace armnn