blob: 9be1ed6d749b7026a46ead8315d27102ea0cae77 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Legacy dispatch helper: routes the support decision to the supplied
// float32 / uint8 functions based on dataType. All other type slots of
// the generic dispatcher (Float16 and the two trailing slots) are wired
// to FalseFunc, i.e. reported as unsupported by this helper.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
namespace
{
// Evaluates 'rule'; when it fails and a reason string was supplied,
// appends the reason to the optional out-parameter (one reason per
// line) so a caller can accumulate diagnostics across several rules.
template<typename F>
bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
{
    bool supported = rule();
    if (!supported && reason)
    {
        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
    }
    return supported;
}

// Base for all support rules below: the derived constructor computes
// the verdict into m_Res; operator() just reports it.
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    bool m_Res = true;
};

// Recursion terminator: a single TensorInfo is trivially consistent.
template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

// True when every TensorInfo in the pack carries the same DataType.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}

// Rule: all given tensors share one DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};

// Rule: both tensors carry identical quantization scale and offset.
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};

// Rule: the tensor's DataType appears in the given container of types.
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
            {
                return dt == info.GetDataType();
            });
    }
};

// Rule: the bias type equals the one GetBiasTypeFromWeightsType derives
// from the weights' DataType.
struct BiasAndWeightsTypesMatch : public Rule
{
    BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
    {
        m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
    }
};

// Rule: the bias type derived from the weights' DataType is one of the
// types listed in the container.
struct BiasAndWeightsTypesCompatible : public Rule
{
    template<typename Container>
    BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
            {
                return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
            });
    }
};

// Rule: both shapes have the same number of dimensions.
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};

// Rule: both shapes contain the same total number of elements.
struct ShapesAreSameTotalSize : public Rule
{
    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetNumElements() == info1.GetNumElements();
    }
};

// Rule: in0 and in1 are implicitly broadcastable to 'out' - each output
// dimension must equal the corresponding input dimension, or that input
// dimension must be 1 (missing leading dimensions count as 1).
// NOTE(review): assumes each input's rank <= the output's rank; the
// unsigned subtraction in CalcInputSize would wrap otherwise - confirm
// callers guarantee this.
struct ShapesAreBroadcastCompatible : public Rule
{
    // Size of 'in' along output axis idx, treating absent leading
    // dimensions as size 1.
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Stop early (— && m_Res) as soon as one axis is incompatible.
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};
} // namespace
181
182
// Returns true when the reference backend can execute this activation:
// input/output must be one of the supported types, match each other in
// type and rank, and the requested activation function must be one the
// reference workload implements. Failure reasons are accumulated into
// reasonIfUnsupported.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: the requested activation function is one of those the
    // reference backend implements.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
245
246bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
247 const TensorInfo& input1,
248 const TensorInfo& output,
249 Optional<std::string&> reasonIfUnsupported) const
250{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000251 bool supported = true;
252
Sadik Armagan2999a022019-04-09 14:20:12 +0100253 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000254 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100255 DataType::QuantisedAsymm8,
256 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000257 };
258
259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
260 "Reference addition: input 0 is not a supported type.");
261
262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
263 "Reference addition: input 1 is not a supported type.");
264
265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266 "Reference addition: output is not a supported type.");
267
268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
269 "Reference addition: input 0 and Input 1 types are mismatched");
270
271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
272 "Reference addition: input and output types are mismatched");
273
274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
275 "Reference addition: shapes are not suitable for implicit broadcast.");
276
277 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
280bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const TensorInfo& mean,
283 const TensorInfo& var,
284 const TensorInfo& beta,
285 const TensorInfo& gamma,
286 const BatchNormalizationDescriptor& descriptor,
287 Optional<std::string&> reasonIfUnsupported) const
288{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100289 ignore_unused(output);
290 ignore_unused(mean);
291 ignore_unused(var);
292 ignore_unused(beta);
293 ignore_unused(gamma);
294 ignore_unused(descriptor);
295 return IsSupportedForDataTypeRef(reasonIfUnsupported,
296 input.GetDataType(),
297 &TrueFunc<>,
298 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100299}
300
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000301bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
302 const TensorInfo& output,
303 const BatchToSpaceNdDescriptor& descriptor,
304 Optional<std::string&> reasonIfUnsupported) const
305{
306 ignore_unused(descriptor);
307 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
308 input.GetDataType(),
309 &TrueFunc<>,
310 &TrueFunc<>) &&
311 IsSupportedForDataTypeRef(reasonIfUnsupported,
312 output.GetDataType(),
313 &TrueFunc<>,
314 &TrueFunc<>));
315}
316
// Concat is the renamed Merger layer: delegate to the deprecated
// IsMergerSupported entry point, suppressing deprecation warnings until
// that function is removed.
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ARMNN_NO_DEPRECATE_WARN_BEGIN
    return IsMergerSupported(inputs, output, descriptor, reasonIfUnsupported);
    ARMNN_NO_DEPRECATE_WARN_END
}
326
arovir011c7c81b2018-10-08 11:34:28 +0100327bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
328 Optional<std::string&> reasonIfUnsupported) const
329{
Nina Drozd58ef2c62019-05-16 12:09:18 +0100330 std::array<DataType,4> supportedTypes = {
331 DataType::Float32,
332 DataType::Signed32,
333 DataType::QuantisedAsymm8,
334 DataType::QuantisedSymm16
335 };
336
337 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
338 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100339}
340
// Fp16->Fp32 conversion: the input is accepted only in the Float16 slot
// (the Float32 slot maps to FalseInputFuncF32 and all remaining slots to
// False functions), and the output only in the Float32 slot.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
360
// Fp32->Fp16 conversion: mirror of IsConvertFp16ToFp32Supported - the
// input is accepted only in the Float32 slot (Float16 input maps to
// FalseInputFuncF16), and the output only in the Float16 slot.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
380
381bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
382 const TensorInfo& output,
383 const Convolution2dDescriptor& descriptor,
384 const TensorInfo& weights,
385 const Optional<TensorInfo>& biases,
386 Optional<std::string&> reasonIfUnsupported) const
387{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100388 bool supported = true;
389
390 // Define supported types.
391 std::array<DataType,3> supportedTypes = {
392 DataType::Float32,
393 DataType::QuantisedAsymm8,
394 DataType::QuantisedSymm16
395 };
396
397 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
398 "Reference addition: input is not a supported type.");
399
400 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
401 "Reference addition: output is not a supported type.");
402
403 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
404 "Reference addition: weights is not a supported type.");
405
406 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
407 "Reference activation: input and output types mismatched.");
408
409 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
410 "Reference activation: input and weights types mismatched.");
411
412 if (biases.has_value())
413 {
414 std::array<DataType,3> biasesSupportedTypes = {
415 DataType::Float32,
416 DataType::Signed32
417 };
418 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
419 "Reference addition: biases is not a supported type.");
420 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100421 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100422
423 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100424}
425
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000426bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
427 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000428 Optional<std::string&> reasonIfUnsupported) const
429{
430 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000431 return IsSupportedForDataTypeRef(reasonIfUnsupported,
432 input.GetDataType(),
433 &TrueFunc<>,
434 &TrueFunc<>);
435}
436
arovir011c7c81b2018-10-08 11:34:28 +0100437bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
438 const TensorInfo& output,
439 const DepthwiseConvolution2dDescriptor& descriptor,
440 const TensorInfo& weights,
441 const Optional<TensorInfo>& biases,
442 Optional<std::string&> reasonIfUnsupported) const
443{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100444 ignore_unused(output);
445 ignore_unused(descriptor);
446 ignore_unused(weights);
447 ignore_unused(biases);
448 return IsSupportedForDataTypeRef(reasonIfUnsupported,
449 input.GetDataType(),
450 &TrueFunc<>,
451 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100452}
453
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000454bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
455 const TensorInfo& output,
456 Optional<std::string&> reasonIfUnsupported) const
457{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100458 bool supported = true;
459
460 std::array<DataType,2> supportedInputTypes = {
461 DataType::QuantisedAsymm8,
462 DataType::QuantisedSymm16
463 };
464
465 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
466 "Reference dequantize: input type not supported.");
467
468 std::array<DataType,2> supportedOutputTypes = {
469 DataType::Float32,
470 };
471
472 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
473 "Reference dequantize: output type not supported.");
474
475 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
476 "Reference dequantize: input and output shapes have different num total elements.");
477
478 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000479}
480
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000481bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
482 const armnn::TensorInfo& input1,
483 const armnn::DetectionPostProcessDescriptor& descriptor,
484 armnn::Optional<std::string&> reasonIfUnsupported) const
485{
486 ignore_unused(input1);
487 return IsSupportedForDataTypeRef(reasonIfUnsupported,
488 input0.GetDataType(),
489 &TrueFunc<>,
490 &TrueFunc<>);
491}
492
Pablo Tellof0bd6832019-04-26 17:58:13 +0100493bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
494 const TensorInfo& output,
495 const DepthwiseConvolution2dDescriptor& descriptor,
496 const TensorInfo& weights,
497 const Optional<TensorInfo>& biases,
498 Optional<std::string&> reasonIfUnsupported) const
499{
500 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
501 {
502 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
503 }
504 else
505 {
506 if (reasonIfUnsupported)
507 {
508 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
509 }
510 return false;
511 }
512}
513
514
515 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100516 const TensorInfo& input1,
517 const TensorInfo& output,
518 Optional<std::string&> reasonIfUnsupported) const
519{
Sadik Armagan2999a022019-04-09 14:20:12 +0100520 bool supported = true;
521
522 std::array<DataType,3> supportedTypes = {
523 DataType::Float32,
524 DataType::QuantisedAsymm8,
525 DataType::QuantisedSymm16
526 };
527
528 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
529 "Reference division: input 0 is not a supported type.");
530
531 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
532 "Reference division: input 1 is not a supported type.");
533
534 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
535 "Reference division: output is not a supported type.");
536
537 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
538 "Reference division: input 0 and Input 1 types are mismatched");
539
540 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
541 "Reference division: input and output types are mismatched");
542
543 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
544 "Reference division: shapes are not suitable for implicit broadcast.");
545
546 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100547}
548
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000549bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
550 const TensorInfo& input1,
551 const TensorInfo& output,
552 Optional<std::string&> reasonIfUnsupported) const
553{
554 ignore_unused(input0);
555 ignore_unused(input1);
556 ignore_unused(output);
557 ignore_unused(reasonIfUnsupported);
558 return IsSupportedForDataTypeRef(reasonIfUnsupported,
559 input0.GetDataType(),
560 &TrueFunc<>,
561 &TrueFunc<>);
562}
563
arovir011c7c81b2018-10-08 11:34:28 +0100564bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
565 const FakeQuantizationDescriptor& descriptor,
566 Optional<std::string&> reasonIfUnsupported) const
567{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100568 ignore_unused(descriptor);
569 return IsSupportedForDataTypeRef(reasonIfUnsupported,
570 input.GetDataType(),
571 &TrueFunc<>,
572 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100573}
574
575bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
576 const TensorInfo& output,
577 Optional<std::string&> reasonIfUnsupported) const
578{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100579 ignore_unused(output);
580 return IsSupportedForDataTypeRef(reasonIfUnsupported,
581 input.GetDataType(),
582 &TrueFunc<>,
583 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100584}
585
// Reference fully connected: input, output and weights must share one of
// the supported types; when the descriptor enables a bias, the bias type
// must be Float32 or Signed32 and must agree with the type derived from
// the weights. Failure reasons are accumulated into reasonIfUnsupported.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 2>
        supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}
641
narpra014951d842019-01-18 16:53:53 +0000642bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
643 const armnn::TensorInfo& input1,
644 const armnn::TensorInfo& output,
645 armnn::Optional<std::string&> reasonIfUnsupported) const
646{
647 ignore_unused(input1);
648 ignore_unused(output);
649 return IsSupportedForDataTypeRef(reasonIfUnsupported,
650 input0.GetDataType(),
651 &TrueFunc<>,
652 &TrueFunc<>);
653}
654
FrancisMurtagh878f0232018-12-19 10:56:15 +0000655bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
656 const TensorInfo& input1,
657 const TensorInfo& output,
658 Optional<std::string&> reasonIfUnsupported) const
659{
660 ignore_unused(input0);
661 ignore_unused(input1);
662 ignore_unused(output);
663 ignore_unused(reasonIfUnsupported);
664 return IsSupportedForDataTypeRef(reasonIfUnsupported,
665 input0.GetDataType(),
666 &TrueFunc<>,
667 &TrueFunc<>);
668}
669
arovir011c7c81b2018-10-08 11:34:28 +0100670bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
671 Optional<std::string&> reasonIfUnsupported) const
672{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100673 return IsSupportedForDataTypeRef(reasonIfUnsupported,
674 input.GetDataType(),
675 &TrueFunc<>,
676 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100677}
678
679bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
680 const TensorInfo& output,
681 const L2NormalizationDescriptor& descriptor,
682 Optional<std::string&> reasonIfUnsupported) const
683{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100684 ignore_unused(output);
685 ignore_unused(descriptor);
686 return IsSupportedForDataTypeRef(reasonIfUnsupported,
687 input.GetDataType(),
688 &TrueFunc<>,
689 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100690}
691
// Reference LSTM: the input must be Float32 or QuantisedSymm16, and all
// state/scratch/output tensors must share the input's data type. The
// weight and bias tensors are currently not validated (all explicitly
// ignored below).
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
768
saoste012df12b32018-11-28 16:57:20 +0000769bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
770 const TensorInfo& input1,
771 const TensorInfo& output,
772 Optional<std::string&> reasonIfUnsupported) const
773{
Sadik Armagan2999a022019-04-09 14:20:12 +0100774 bool supported = true;
775
776 std::array<DataType,3> supportedTypes = {
777 DataType::Float32,
778 DataType::QuantisedAsymm8,
779 DataType::QuantisedSymm16
780 };
781
782 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
783 "Reference maximum: input 0 is not a supported type.");
784
785 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
786 "Reference maximum: input 1 is not a supported type.");
787
788 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
789 "Reference maximum: output is not a supported type.");
790
791 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
792 "Reference maximum: input 0 and Input 1 types are mismatched");
793
794 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
795 "Reference maximum: input and output types are mismatched");
796
797 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
798 "Reference maximum: shapes are not suitable for implicit broadcast.");
799
800 return supported;
saoste012df12b32018-11-28 16:57:20 +0000801}
802
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100803bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
804 const TensorInfo& output,
805 const MeanDescriptor& descriptor,
806 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100807{
narpra011e4c31d2018-09-28 11:07:51 +0100808 ignore_unused(output);
809 ignore_unused(descriptor);
810 return IsSupportedForDataTypeRef(reasonIfUnsupported,
811 input.GetDataType(),
812 &TrueFunc<>,
813 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100814}
815
// Reference merger (deprecated name for concatenation): the output and
// every input must use a supported type, and every input must match the
// output's type. The origins descriptor is not validated.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    // Every input tensor is checked individually; all failures are
    // accumulated rather than stopping at the first.
    for (const TensorInfo* input : inputs)
    {
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}
844
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000845bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
846 const TensorInfo &output,
847 Optional<std::string &> reasonIfUnsupported) const
848{
849 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000850 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
851 input.GetDataType(),
852 &TrueFunc<>,
853 &TrueFunc<>,
854 &TrueFunc<>,
855 &FalseFuncI32<>,
856 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000857}
858
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000859bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
860 const TensorInfo& input1,
861 const TensorInfo& output,
862 Optional<std::string&> reasonIfUnsupported) const
863{
Sadik Armagan2999a022019-04-09 14:20:12 +0100864 bool supported = true;
865
866 std::array<DataType,3> supportedTypes = {
867 DataType::Float32,
868 DataType::QuantisedAsymm8,
869 DataType::QuantisedSymm16
870 };
871
872 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
873 "Reference minimum: input 0 is not a supported type.");
874
875 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
876 "Reference minimum: input 1 is not a supported type.");
877
878 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
879 "Reference minimum: output is not a supported type.");
880
881 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
882 "Reference minimum: input 0 and Input 1 types are mismatched");
883
884 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
885 "Reference minimum: input and output types are mismatched");
886
887 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
888 "Reference minimum: shapes are not suitable for implicit broadcast.");
889
890 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000891}
892
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100893bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
894 const TensorInfo& input1,
895 const TensorInfo& output,
896 Optional<std::string&> reasonIfUnsupported) const
897{
Sadik Armagan2999a022019-04-09 14:20:12 +0100898 bool supported = true;
899
900 std::array<DataType,3> supportedTypes = {
901 DataType::Float32,
902 DataType::QuantisedAsymm8,
903 DataType::QuantisedSymm16
904 };
905
906 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
907 "Reference multiplication: input 0 is not a supported type.");
908
909 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
910 "Reference multiplication: input 1 is not a supported type.");
911
912 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
913 "Reference multiplication: output is not a supported type.");
914
915 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
916 "Reference multiplication: input 0 and Input 1 types are mismatched");
917
918 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
919 "Reference multiplication: input and output types are mismatched");
920
921 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
922 "Reference multiplication: shapes are not suitable for implicit broadcast.");
923
924 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100925}
926
927bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
928 const TensorInfo& output,
929 const NormalizationDescriptor& descriptor,
930 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100931{
932 ignore_unused(output);
933 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100934 return IsSupportedForDataTypeRef(reasonIfUnsupported,
935 input.GetDataType(),
936 &TrueFunc<>,
937 &FalseFuncU8<>);
938}
939
940bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
941 Optional<std::string&> reasonIfUnsupported) const
942{
kevmay012b4d88e2019-01-24 14:05:09 +0000943 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
944 output.GetDataType(),
945 &TrueFunc<>,
946 &TrueFunc<>,
947 &TrueFunc<>,
948 &FalseFuncI32<>,
949 &TrueFunc<>);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100950}
951
952bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
953 const TensorInfo& output,
954 const PadDescriptor& descriptor,
955 Optional<std::string&> reasonIfUnsupported) const
956{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100957 ignore_unused(output);
958 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000959 return IsSupportedForDataTypeRef(reasonIfUnsupported,
960 input.GetDataType(),
961 &TrueFunc<>,
962 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100963}
964
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100965bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
966 const TensorInfo& output,
967 const PermuteDescriptor& descriptor,
968 Optional<std::string&> reasonIfUnsupported) const
969{
970 ignore_unused(output);
971 ignore_unused(descriptor);
972 return IsSupportedForDataTypeRef(reasonIfUnsupported,
973 input.GetDataType(),
974 &TrueFunc<>,
975 &TrueFunc<>);
976}
977
978bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
979 const TensorInfo& output,
980 const Pooling2dDescriptor& descriptor,
981 Optional<std::string&> reasonIfUnsupported) const
982{
983 ignore_unused(output);
984 ignore_unused(descriptor);
985 return IsSupportedForDataTypeRef(reasonIfUnsupported,
986 input.GetDataType(),
987 &TrueFunc<>,
988 &TrueFunc<>);
989}
990
Derek Lamberti5f400d62019-03-25 15:41:58 +0000991bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
992 const TensorInfo& output,
993 Optional<std::string&> reasonIfUnsupported) const
994{
995 bool supported = true;
996
997 // Define supported output types.
998 std::array<DataType,2> supportedInputTypes = {
999 DataType::Float32,
1000 };
1001
1002 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1003 "Reference quantize: input type not supported.");
1004
1005 // Define supported output types.
1006 std::array<DataType,2> supportedOutputTypes = {
1007 DataType::QuantisedAsymm8,
1008 DataType::QuantisedSymm16
1009 };
1010 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1011 "Reference quantize: output type not supported.");
1012
1013 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1014 "Reference quantize: input and output shapes have different num total elements.");
1015
1016 return supported;
1017}
1018
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001019bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001020 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001021 Optional<std::string&> reasonIfUnsupported) const
1022{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001023 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001024 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1025 input.GetDataType(),
1026 &TrueFunc<>,
1027 &TrueFunc<>);
1028}
1029
1030bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001031 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001032 Optional<std::string&> reasonIfUnsupported) const
1033{
Sadik Armaganc625f002018-12-17 11:32:16 +00001034 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001035 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1036 input.GetDataType(),
1037 &TrueFunc<>,
1038 &TrueFunc<>);
1039}
1040
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001041bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1042 const TensorInfo& output,
1043 Optional<std::string&> reasonIfUnsupported) const
1044{
1045 ignore_unused(output);
1046 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1047 input.GetDataType(),
1048 &TrueFunc<>,
1049 &FalseFuncU8<>);
1050}
1051
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001052bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1053 const TensorInfo& output,
1054 const SoftmaxDescriptor& descriptor,
1055 Optional<std::string&> reasonIfUnsupported) const
1056{
1057 ignore_unused(output);
1058 ignore_unused(descriptor);
1059 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1060 input.GetDataType(),
1061 &TrueFunc<>,
1062 &TrueFunc<>);
1063}
1064
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001065bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1066 const TensorInfo& output,
1067 const SpaceToBatchNdDescriptor& descriptor,
1068 Optional<std::string&> reasonIfUnsupported) const
1069{
1070 ignore_unused(output);
1071 ignore_unused(descriptor);
1072 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1073 input.GetDataType(),
1074 &TrueFunc<>,
1075 &TrueFunc<>);
1076}
1077
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001078bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1079 const ViewsDescriptor& descriptor,
1080 Optional<std::string&> reasonIfUnsupported) const
1081{
1082 ignore_unused(descriptor);
1083 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1084 input.GetDataType(),
1085 &TrueFunc<>,
1086 &TrueFunc<>);
1087}
1088
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001089bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1090 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1091 const ViewsDescriptor& descriptor,
1092 Optional<std::string&> reasonIfUnsupported) const
1093{
1094 ignore_unused(descriptor);
1095 ignore_unused(outputs);
1096 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1097 input.GetDataType(),
1098 &TrueFunc<>,
1099 &TrueFunc<>);
1100}
1101
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001102bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1103 const TensorInfo& output,
1104 const StridedSliceDescriptor& descriptor,
1105 Optional<std::string&> reasonIfUnsupported) const
1106{
1107 ignore_unused(output);
1108 ignore_unused(descriptor);
1109 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1110 input.GetDataType(),
1111 &TrueFunc<>,
1112 &TrueFunc<>);
1113}
1114
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001115bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1116 const TensorInfo& input1,
1117 const TensorInfo& output,
1118 Optional<std::string&> reasonIfUnsupported) const
1119{
Sadik Armagan2999a022019-04-09 14:20:12 +01001120 bool supported = true;
1121
1122 std::array<DataType,3> supportedTypes = {
1123 DataType::Float32,
1124 DataType::QuantisedAsymm8,
1125 DataType::QuantisedSymm16
1126 };
1127
1128 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1129 "Reference subtraction: input 0 is not a supported type.");
1130
1131 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1132 "Reference subtraction: input 1 is not a supported type.");
1133
1134 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1135 "Reference subtraction: output is not a supported type.");
1136
1137 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1138 "Reference subtraction: input 0 and Input 1 types are mismatched");
1139
1140 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1141 "Reference subtraction: input and output types are mismatched");
1142
1143 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1144 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1145
1146 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001147}
1148
arovir011c7c81b2018-10-08 11:34:28 +01001149} // namespace armnn