blob: b6da628be3fe677403fab31aa0c53ce648adb8b4 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
namespace
{

// Convenience wrapper around IsSupportedForDataTypeGeneric that enables only
// the Float32 and Uint8 branches for the reference backend.
//
// The positional functors handed to the generic helper map to data types in
// the order Float16, Float32, Uint8, Signed32, Boolean — inferred from the
// named functors (FalseInputFuncF32, FalseFuncU8, FalseFuncI32) used with the
// same helper elsewhere in this file; confirm against LayerSupportCommon.hpp.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: always rejected here
                                         floatFuncPtr,          // Float32
                                         uint8FuncPtr,          // Uint8
                                         &FalseFunc<Params...>, // Signed32: always rejected here
                                         &FalseFunc<Params...>, // Boolean: always rejected here
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
63struct Rule
64{
65 bool operator()() const
66 {
67 return m_Res;
68 }
69
70 bool m_Res = true;
71};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
87struct TypesAreEqual : public Rule
88{
89 template<typename ... Ts>
90 TypesAreEqual(const Ts&... ts)
91 {
92 m_Res = AllTypesAreEqualImpl(ts...);
93 }
94};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
Derek Lamberti50db4e82019-03-13 14:16:15 +0000133struct ShapesAreBroadcastCompatible : public Rule
134{
135 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
136 {
137 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
138 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
139 return sizeIn;
140 }
141
142 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
143 {
144 const TensorShape& shape0 = in0.GetShape();
145 const TensorShape& shape1 = in1.GetShape();
146 const TensorShape& outShape = out.GetShape();
147
148 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
149 {
150 unsigned int sizeOut = outShape[i];
151 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
152 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
153
154 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
155 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
156 }
157 }
158};
159} // namespace
160
161
arovir011c7c81b2018-10-08 11:34:28 +0100162bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const ActivationDescriptor& descriptor,
165 Optional<std::string&> reasonIfUnsupported) const
166{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000167 bool supported = true;
168
169 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100170 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000171 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100172 DataType::QuantisedAsymm8,
173 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000174 };
175
176 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
177 "Reference activation: input type not supported.");
178
179 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
180 "Reference activation: output type not supported.");
181
182 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
183 "Reference activation: input and output types mismatched.");
184
185 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
186 "Reference activation: input and output shapes are of different rank.");
187
188
189 struct ActivationFunctionSupported : public Rule
190 {
191 ActivationFunctionSupported(const ActivationDescriptor& desc)
192 {
193 switch(desc.m_Function)
194 {
195 case ActivationFunction::Abs:
196 case ActivationFunction::BoundedReLu:
197 case ActivationFunction::LeakyReLu:
198 case ActivationFunction::Linear:
199 case ActivationFunction::ReLu:
200 case ActivationFunction::Sigmoid:
201 case ActivationFunction::SoftReLu:
202 case ActivationFunction::Sqrt:
203 case ActivationFunction::Square:
204 case ActivationFunction::TanH:
205 {
206 m_Res = true;
207 break;
208 }
209 default:
210 {
211 m_Res = false;
212 break;
213 }
214 }
215 }
216 };
217
218 // Function is supported
219 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
220 "Reference activation: function not supported.");
221
222 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100223}
224
225bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
226 const TensorInfo& input1,
227 const TensorInfo& output,
228 Optional<std::string&> reasonIfUnsupported) const
229{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000230 bool supported = true;
231
Sadik Armagan2999a022019-04-09 14:20:12 +0100232 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000233 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100234 DataType::QuantisedAsymm8,
235 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000236 };
237
238 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
239 "Reference addition: input 0 is not a supported type.");
240
241 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
242 "Reference addition: input 1 is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
245 "Reference addition: output is not a supported type.");
246
247 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
248 "Reference addition: input 0 and Input 1 types are mismatched");
249
250 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
251 "Reference addition: input and output types are mismatched");
252
253 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
254 "Reference addition: shapes are not suitable for implicit broadcast.");
255
256 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100257}
258
259bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
260 const TensorInfo& output,
261 const TensorInfo& mean,
262 const TensorInfo& var,
263 const TensorInfo& beta,
264 const TensorInfo& gamma,
265 const BatchNormalizationDescriptor& descriptor,
266 Optional<std::string&> reasonIfUnsupported) const
267{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100268 ignore_unused(output);
269 ignore_unused(mean);
270 ignore_unused(var);
271 ignore_unused(beta);
272 ignore_unused(gamma);
273 ignore_unused(descriptor);
274 return IsSupportedForDataTypeRef(reasonIfUnsupported,
275 input.GetDataType(),
276 &TrueFunc<>,
277 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000280bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const BatchToSpaceNdDescriptor& descriptor,
283 Optional<std::string&> reasonIfUnsupported) const
284{
285 ignore_unused(descriptor);
286 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
287 input.GetDataType(),
288 &TrueFunc<>,
289 &TrueFunc<>) &&
290 IsSupportedForDataTypeRef(reasonIfUnsupported,
291 output.GetDataType(),
292 &TrueFunc<>,
293 &TrueFunc<>));
294}
295
// Concat support is currently delegated to the (deprecated) Merger check —
// the two layers share the reference implementation. The macro pair
// suppresses the deprecation warning for that one call.
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ARMNN_NO_DEPRECATE_WARN_BEGIN
    return IsMergerSupported(inputs, output, descriptor, reasonIfUnsupported);
    ARMNN_NO_DEPRECATE_WARN_END
}
305
arovir011c7c81b2018-10-08 11:34:28 +0100306bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
307 Optional<std::string&> reasonIfUnsupported) const
308{
Nina Drozd58ef2c62019-05-16 12:09:18 +0100309 std::array<DataType,4> supportedTypes = {
310 DataType::Float32,
311 DataType::Signed32,
312 DataType::QuantisedAsymm8,
313 DataType::QuantisedSymm16
314 };
315
316 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
317 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100318}
319
// Fp16->Fp32 conversion: the input must be Float16 and the output Float32.
// The positional functors map to data types in the order Float16, Float32,
// Uint8, Signed32, Boolean (inferred from the named functors; confirm in
// LayerSupportCommon.hpp).
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,           // Float16 input: accepted
                                          &FalseInputFuncF32<>,  // Float32 input: rejected
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output: rejected
                                          &TrueFunc<>,           // Float32 output: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
339
// Fp32->Fp16 conversion: the input must be Float32 and the output Float16 —
// the mirror image of IsConvertFp16ToFp32Supported. Functor order maps to
// Float16, Float32, Uint8, Signed32, Boolean (inferred from the named
// functors; confirm in LayerSupportCommon.hpp).
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,  // Float16 input: rejected
                                          &TrueFunc<>,           // Float32 input: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output: accepted
                                          &FalseOutputFuncF32<>, // Float32 output: rejected
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
359
360bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
361 const TensorInfo& output,
362 const Convolution2dDescriptor& descriptor,
363 const TensorInfo& weights,
364 const Optional<TensorInfo>& biases,
365 Optional<std::string&> reasonIfUnsupported) const
366{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100367 bool supported = true;
368
369 // Define supported types.
370 std::array<DataType,3> supportedTypes = {
371 DataType::Float32,
372 DataType::QuantisedAsymm8,
373 DataType::QuantisedSymm16
374 };
375
376 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
377 "Reference addition: input is not a supported type.");
378
379 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
380 "Reference addition: output is not a supported type.");
381
382 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
383 "Reference addition: weights is not a supported type.");
384
385 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
386 "Reference activation: input and output types mismatched.");
387
388 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
389 "Reference activation: input and weights types mismatched.");
390
391 if (biases.has_value())
392 {
393 std::array<DataType,3> biasesSupportedTypes = {
394 DataType::Float32,
395 DataType::Signed32
396 };
397 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
398 "Reference addition: biases is not a supported type.");
399 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100400 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100401
402 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100403}
404
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000405bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
406 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000407 Optional<std::string&> reasonIfUnsupported) const
408{
409 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000410 return IsSupportedForDataTypeRef(reasonIfUnsupported,
411 input.GetDataType(),
412 &TrueFunc<>,
413 &TrueFunc<>);
414}
415
arovir011c7c81b2018-10-08 11:34:28 +0100416bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
417 const TensorInfo& output,
418 const DepthwiseConvolution2dDescriptor& descriptor,
419 const TensorInfo& weights,
420 const Optional<TensorInfo>& biases,
421 Optional<std::string&> reasonIfUnsupported) const
422{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100423 ignore_unused(output);
424 ignore_unused(descriptor);
425 ignore_unused(weights);
426 ignore_unused(biases);
427 return IsSupportedForDataTypeRef(reasonIfUnsupported,
428 input.GetDataType(),
429 &TrueFunc<>,
430 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100431}
432
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000433bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
434 const TensorInfo& output,
435 Optional<std::string&> reasonIfUnsupported) const
436{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100437 bool supported = true;
438
439 std::array<DataType,2> supportedInputTypes = {
440 DataType::QuantisedAsymm8,
441 DataType::QuantisedSymm16
442 };
443
444 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
445 "Reference dequantize: input type not supported.");
446
447 std::array<DataType,2> supportedOutputTypes = {
448 DataType::Float32,
449 };
450
451 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
452 "Reference dequantize: output type not supported.");
453
454 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
455 "Reference dequantize: input and output shapes have different num total elements.");
456
457 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000458}
459
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000460bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
461 const armnn::TensorInfo& input1,
462 const armnn::DetectionPostProcessDescriptor& descriptor,
463 armnn::Optional<std::string&> reasonIfUnsupported) const
464{
465 ignore_unused(input1);
466 return IsSupportedForDataTypeRef(reasonIfUnsupported,
467 input0.GetDataType(),
468 &TrueFunc<>,
469 &TrueFunc<>);
470}
471
Pablo Tellof0bd6832019-04-26 17:58:13 +0100472bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
473 const TensorInfo& output,
474 const DepthwiseConvolution2dDescriptor& descriptor,
475 const TensorInfo& weights,
476 const Optional<TensorInfo>& biases,
477 Optional<std::string&> reasonIfUnsupported) const
478{
479 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
480 {
481 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
482 }
483 else
484 {
485 if (reasonIfUnsupported)
486 {
487 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
488 }
489 return false;
490 }
491}
492
493
494 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100495 const TensorInfo& input1,
496 const TensorInfo& output,
497 Optional<std::string&> reasonIfUnsupported) const
498{
Sadik Armagan2999a022019-04-09 14:20:12 +0100499 bool supported = true;
500
501 std::array<DataType,3> supportedTypes = {
502 DataType::Float32,
503 DataType::QuantisedAsymm8,
504 DataType::QuantisedSymm16
505 };
506
507 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
508 "Reference division: input 0 is not a supported type.");
509
510 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
511 "Reference division: input 1 is not a supported type.");
512
513 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
514 "Reference division: output is not a supported type.");
515
516 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
517 "Reference division: input 0 and Input 1 types are mismatched");
518
519 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
520 "Reference division: input and output types are mismatched");
521
522 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
523 "Reference division: shapes are not suitable for implicit broadcast.");
524
525 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100526}
527
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000528bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
529 const TensorInfo& input1,
530 const TensorInfo& output,
531 Optional<std::string&> reasonIfUnsupported) const
532{
533 ignore_unused(input0);
534 ignore_unused(input1);
535 ignore_unused(output);
536 ignore_unused(reasonIfUnsupported);
537 return IsSupportedForDataTypeRef(reasonIfUnsupported,
538 input0.GetDataType(),
539 &TrueFunc<>,
540 &TrueFunc<>);
541}
542
arovir011c7c81b2018-10-08 11:34:28 +0100543bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
544 const FakeQuantizationDescriptor& descriptor,
545 Optional<std::string&> reasonIfUnsupported) const
546{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100547 ignore_unused(descriptor);
548 return IsSupportedForDataTypeRef(reasonIfUnsupported,
549 input.GetDataType(),
550 &TrueFunc<>,
551 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100552}
553
554bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
555 const TensorInfo& output,
556 Optional<std::string&> reasonIfUnsupported) const
557{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100558 ignore_unused(output);
559 return IsSupportedForDataTypeRef(reasonIfUnsupported,
560 input.GetDataType(),
561 &TrueFunc<>,
562 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100563}
564
565bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
566 const TensorInfo& output,
567 const TensorInfo& weights,
568 const TensorInfo& biases,
569 const FullyConnectedDescriptor& descriptor,
570 Optional<std::string&> reasonIfUnsupported) const
571{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100572 ignore_unused(output);
573 ignore_unused(weights);
574 ignore_unused(biases);
575 ignore_unused(descriptor);
576 return IsSupportedForDataTypeRef(reasonIfUnsupported,
577 input.GetDataType(),
578 &TrueFunc<>,
579 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100580}
581
narpra014951d842019-01-18 16:53:53 +0000582bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
583 const armnn::TensorInfo& input1,
584 const armnn::TensorInfo& output,
585 armnn::Optional<std::string&> reasonIfUnsupported) const
586{
587 ignore_unused(input1);
588 ignore_unused(output);
589 return IsSupportedForDataTypeRef(reasonIfUnsupported,
590 input0.GetDataType(),
591 &TrueFunc<>,
592 &TrueFunc<>);
593}
594
FrancisMurtagh878f0232018-12-19 10:56:15 +0000595bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
596 const TensorInfo& input1,
597 const TensorInfo& output,
598 Optional<std::string&> reasonIfUnsupported) const
599{
600 ignore_unused(input0);
601 ignore_unused(input1);
602 ignore_unused(output);
603 ignore_unused(reasonIfUnsupported);
604 return IsSupportedForDataTypeRef(reasonIfUnsupported,
605 input0.GetDataType(),
606 &TrueFunc<>,
607 &TrueFunc<>);
608}
609
arovir011c7c81b2018-10-08 11:34:28 +0100610bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
611 Optional<std::string&> reasonIfUnsupported) const
612{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100613 return IsSupportedForDataTypeRef(reasonIfUnsupported,
614 input.GetDataType(),
615 &TrueFunc<>,
616 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100617}
618
619bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
620 const TensorInfo& output,
621 const L2NormalizationDescriptor& descriptor,
622 Optional<std::string&> reasonIfUnsupported) const
623{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100624 ignore_unused(output);
625 ignore_unused(descriptor);
626 return IsSupportedForDataTypeRef(reasonIfUnsupported,
627 input.GetDataType(),
628 &TrueFunc<>,
629 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100630}
631
632bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
633 const TensorInfo& outputStateIn,
634 const TensorInfo& cellStateIn,
635 const TensorInfo& scratchBuffer,
636 const TensorInfo& outputStateOut,
637 const TensorInfo& cellStateOut,
638 const TensorInfo& output,
639 const LstmDescriptor& descriptor,
640 const TensorInfo& inputToForgetWeights,
641 const TensorInfo& inputToCellWeights,
642 const TensorInfo& inputToOutputWeights,
643 const TensorInfo& recurrentToForgetWeights,
644 const TensorInfo& recurrentToCellWeights,
645 const TensorInfo& recurrentToOutputWeights,
646 const TensorInfo& forgetGateBias,
647 const TensorInfo& cellBias,
648 const TensorInfo& outputGateBias,
649 const TensorInfo* inputToInputWeights,
650 const TensorInfo* recurrentToInputWeights,
651 const TensorInfo* cellToInputWeights,
652 const TensorInfo* inputGateBias,
653 const TensorInfo* projectionWeights,
654 const TensorInfo* projectionBias,
655 const TensorInfo* cellToForgetWeights,
656 const TensorInfo* cellToOutputWeights,
657 Optional<std::string&> reasonIfUnsupported) const
658{
telsoa01c577f2c2018-08-31 09:22:23 +0100659 ignore_unused(descriptor);
660 ignore_unused(inputToForgetWeights);
661 ignore_unused(inputToCellWeights);
662 ignore_unused(inputToOutputWeights);
663 ignore_unused(recurrentToForgetWeights);
664 ignore_unused(recurrentToCellWeights);
665 ignore_unused(recurrentToOutputWeights);
666 ignore_unused(forgetGateBias);
667 ignore_unused(cellBias);
668 ignore_unused(outputGateBias);
669 ignore_unused(inputToInputWeights);
670 ignore_unused(recurrentToInputWeights);
671 ignore_unused(cellToInputWeights);
672 ignore_unused(inputGateBias);
673 ignore_unused(projectionWeights);
674 ignore_unused(projectionBias);
675 ignore_unused(cellToForgetWeights);
676 ignore_unused(cellToOutputWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100677
678 bool supported = true;
679
680 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100681 DataType::Float32,
682 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100683 };
684
685 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
686 "Reference Lstm: input is not a supported type.");
687
688 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
689 "Reference Lstm: input and outputStateIn types are mismatched");
690
691 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
692 "Reference Lstm: input and cellStateIn types are mismatched");
693
694 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
695 "Reference Lstm: input and scratchBuffer types are mismatched");
696
697 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
698 "Reference Lstm: input and outputStateOut types are mismatched");
699
700 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
701 "Reference Lstm: input and cellStateOut types are mismatched");
702
703 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
704 "Reference Lstm: input and output types are mismatched");
705
706 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100707}
708
saoste012df12b32018-11-28 16:57:20 +0000709bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
710 const TensorInfo& input1,
711 const TensorInfo& output,
712 Optional<std::string&> reasonIfUnsupported) const
713{
Sadik Armagan2999a022019-04-09 14:20:12 +0100714 bool supported = true;
715
716 std::array<DataType,3> supportedTypes = {
717 DataType::Float32,
718 DataType::QuantisedAsymm8,
719 DataType::QuantisedSymm16
720 };
721
722 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
723 "Reference maximum: input 0 is not a supported type.");
724
725 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
726 "Reference maximum: input 1 is not a supported type.");
727
728 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
729 "Reference maximum: output is not a supported type.");
730
731 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
732 "Reference maximum: input 0 and Input 1 types are mismatched");
733
734 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
735 "Reference maximum: input and output types are mismatched");
736
737 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
738 "Reference maximum: shapes are not suitable for implicit broadcast.");
739
740 return supported;
saoste012df12b32018-11-28 16:57:20 +0000741}
742
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100743bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
744 const TensorInfo& output,
745 const MeanDescriptor& descriptor,
746 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100747{
narpra011e4c31d2018-09-28 11:07:51 +0100748 ignore_unused(output);
749 ignore_unused(descriptor);
750 return IsSupportedForDataTypeRef(reasonIfUnsupported,
751 input.GetDataType(),
752 &TrueFunc<>,
753 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100754}
755
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100756bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000757 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100758 const OriginsDescriptor& descriptor,
759 Optional<std::string&> reasonIfUnsupported) const
760{
761 ignore_unused(descriptor);
Jim Flynncbb66aa2019-05-15 13:03:54 +0100762
763 bool supported = true;
764 std::array<DataType,3> supportedTypes =
765 {
766 DataType::Float32,
767 DataType::QuantisedAsymm8,
768 DataType::QuantisedSymm16
769 };
770
771 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
772 "Reference concatenation: output type not supported");
773 for (const TensorInfo* input : inputs)
774 {
775 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
776 "Reference concatenation: input type not supported");
777
778 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
779 "Reference concatenation: input and output types mismatched.");
780 }
781
782 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100783}
784
// MemCopy is supported for every data type handled by the generic check
// except Signed32 (rejected via FalseFuncI32). Functor order maps to
// Float16, Float32, Uint8, Signed32, Boolean (inferred from the named
// functors used with this helper elsewhere in the file; confirm in
// LayerSupportCommon.hpp).
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,     // Float16
                                         &TrueFunc<>,     // Float32
                                         &TrueFunc<>,     // Uint8
                                         &FalseFuncI32<>, // Signed32: rejected
                                         &TrueFunc<>);    // Boolean
}
798
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000799bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
800 const TensorInfo& input1,
801 const TensorInfo& output,
802 Optional<std::string&> reasonIfUnsupported) const
803{
Sadik Armagan2999a022019-04-09 14:20:12 +0100804 bool supported = true;
805
806 std::array<DataType,3> supportedTypes = {
807 DataType::Float32,
808 DataType::QuantisedAsymm8,
809 DataType::QuantisedSymm16
810 };
811
812 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
813 "Reference minimum: input 0 is not a supported type.");
814
815 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
816 "Reference minimum: input 1 is not a supported type.");
817
818 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
819 "Reference minimum: output is not a supported type.");
820
821 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
822 "Reference minimum: input 0 and Input 1 types are mismatched");
823
824 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
825 "Reference minimum: input and output types are mismatched");
826
827 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
828 "Reference minimum: shapes are not suitable for implicit broadcast.");
829
830 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000831}
832
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100833bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
834 const TensorInfo& input1,
835 const TensorInfo& output,
836 Optional<std::string&> reasonIfUnsupported) const
837{
Sadik Armagan2999a022019-04-09 14:20:12 +0100838 bool supported = true;
839
840 std::array<DataType,3> supportedTypes = {
841 DataType::Float32,
842 DataType::QuantisedAsymm8,
843 DataType::QuantisedSymm16
844 };
845
846 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
847 "Reference multiplication: input 0 is not a supported type.");
848
849 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
850 "Reference multiplication: input 1 is not a supported type.");
851
852 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
853 "Reference multiplication: output is not a supported type.");
854
855 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
856 "Reference multiplication: input 0 and Input 1 types are mismatched");
857
858 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
859 "Reference multiplication: input and output types are mismatched");
860
861 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
862 "Reference multiplication: shapes are not suitable for implicit broadcast.");
863
864 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100865}
866
867bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
868 const TensorInfo& output,
869 const NormalizationDescriptor& descriptor,
870 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100871{
872 ignore_unused(output);
873 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100874 return IsSupportedForDataTypeRef(reasonIfUnsupported,
875 input.GetDataType(),
876 &TrueFunc<>,
877 &FalseFuncU8<>);
878}
879
880bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
881 Optional<std::string&> reasonIfUnsupported) const
882{
kevmay012b4d88e2019-01-24 14:05:09 +0000883 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
884 output.GetDataType(),
885 &TrueFunc<>,
886 &TrueFunc<>,
887 &TrueFunc<>,
888 &FalseFuncI32<>,
889 &TrueFunc<>);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100890}
891
892bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
893 const TensorInfo& output,
894 const PadDescriptor& descriptor,
895 Optional<std::string&> reasonIfUnsupported) const
896{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100897 ignore_unused(output);
898 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000899 return IsSupportedForDataTypeRef(reasonIfUnsupported,
900 input.GetDataType(),
901 &TrueFunc<>,
902 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100903}
904
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100905bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
906 const TensorInfo& output,
907 const PermuteDescriptor& descriptor,
908 Optional<std::string&> reasonIfUnsupported) const
909{
910 ignore_unused(output);
911 ignore_unused(descriptor);
912 return IsSupportedForDataTypeRef(reasonIfUnsupported,
913 input.GetDataType(),
914 &TrueFunc<>,
915 &TrueFunc<>);
916}
917
918bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
919 const TensorInfo& output,
920 const Pooling2dDescriptor& descriptor,
921 Optional<std::string&> reasonIfUnsupported) const
922{
923 ignore_unused(output);
924 ignore_unused(descriptor);
925 return IsSupportedForDataTypeRef(reasonIfUnsupported,
926 input.GetDataType(),
927 &TrueFunc<>,
928 &TrueFunc<>);
929}
930
Derek Lamberti5f400d62019-03-25 15:41:58 +0000931bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
932 const TensorInfo& output,
933 Optional<std::string&> reasonIfUnsupported) const
934{
935 bool supported = true;
936
937 // Define supported output types.
938 std::array<DataType,2> supportedInputTypes = {
939 DataType::Float32,
940 };
941
942 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
943 "Reference quantize: input type not supported.");
944
945 // Define supported output types.
946 std::array<DataType,2> supportedOutputTypes = {
947 DataType::QuantisedAsymm8,
948 DataType::QuantisedSymm16
949 };
950 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
951 "Reference quantize: output type not supported.");
952
953 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
954 "Reference quantize: input and output shapes have different num total elements.");
955
956 return supported;
957}
958
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100959bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000960 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100961 Optional<std::string&> reasonIfUnsupported) const
962{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000963 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100964 return IsSupportedForDataTypeRef(reasonIfUnsupported,
965 input.GetDataType(),
966 &TrueFunc<>,
967 &TrueFunc<>);
968}
969
970bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000971 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100972 Optional<std::string&> reasonIfUnsupported) const
973{
Sadik Armaganc625f002018-12-17 11:32:16 +0000974 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100975 return IsSupportedForDataTypeRef(reasonIfUnsupported,
976 input.GetDataType(),
977 &TrueFunc<>,
978 &TrueFunc<>);
979}
980
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000981bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
982 const TensorInfo& output,
983 Optional<std::string&> reasonIfUnsupported) const
984{
985 ignore_unused(output);
986 return IsSupportedForDataTypeRef(reasonIfUnsupported,
987 input.GetDataType(),
988 &TrueFunc<>,
989 &FalseFuncU8<>);
990}
991
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100992bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
993 const TensorInfo& output,
994 const SoftmaxDescriptor& descriptor,
995 Optional<std::string&> reasonIfUnsupported) const
996{
997 ignore_unused(output);
998 ignore_unused(descriptor);
999 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1000 input.GetDataType(),
1001 &TrueFunc<>,
1002 &TrueFunc<>);
1003}
1004
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001005bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1006 const TensorInfo& output,
1007 const SpaceToBatchNdDescriptor& descriptor,
1008 Optional<std::string&> reasonIfUnsupported) const
1009{
1010 ignore_unused(output);
1011 ignore_unused(descriptor);
1012 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1013 input.GetDataType(),
1014 &TrueFunc<>,
1015 &TrueFunc<>);
1016}
1017
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001018bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1019 const ViewsDescriptor& descriptor,
1020 Optional<std::string&> reasonIfUnsupported) const
1021{
1022 ignore_unused(descriptor);
1023 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1024 input.GetDataType(),
1025 &TrueFunc<>,
1026 &TrueFunc<>);
1027}
1028
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001029bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1030 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1031 const ViewsDescriptor& descriptor,
1032 Optional<std::string&> reasonIfUnsupported) const
1033{
1034 ignore_unused(descriptor);
1035 ignore_unused(outputs);
1036 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1037 input.GetDataType(),
1038 &TrueFunc<>,
1039 &TrueFunc<>);
1040}
1041
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001042bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1043 const TensorInfo& output,
1044 const StridedSliceDescriptor& descriptor,
1045 Optional<std::string&> reasonIfUnsupported) const
1046{
1047 ignore_unused(output);
1048 ignore_unused(descriptor);
1049 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1050 input.GetDataType(),
1051 &TrueFunc<>,
1052 &TrueFunc<>);
1053}
1054
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001055bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1056 const TensorInfo& input1,
1057 const TensorInfo& output,
1058 Optional<std::string&> reasonIfUnsupported) const
1059{
Sadik Armagan2999a022019-04-09 14:20:12 +01001060 bool supported = true;
1061
1062 std::array<DataType,3> supportedTypes = {
1063 DataType::Float32,
1064 DataType::QuantisedAsymm8,
1065 DataType::QuantisedSymm16
1066 };
1067
1068 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1069 "Reference subtraction: input 0 is not a supported type.");
1070
1071 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1072 "Reference subtraction: input 1 is not a supported type.");
1073
1074 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1075 "Reference subtraction: output is not a supported type.");
1076
1077 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1078 "Reference subtraction: input 0 and Input 1 types are mismatched");
1079
1080 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1081 "Reference subtraction: input and output types are mismatched");
1082
1083 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1084 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1085
1086 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001087}
1088
arovir011c7c81b2018-10-08 11:34:28 +01001089} // namespace armnn