blob: a1d8e7de81a633c449f3375920f1062bc5d08edb [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
30template<typename Float32Func, typename Uint8Func, typename ... Params>
31bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
32 DataType dataType,
33 Float32Func floatFuncPtr,
34 Uint8Func uint8FuncPtr,
35 Params&&... params)
36{
37 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
38 dataType,
39 &FalseFunc<Params...>,
40 floatFuncPtr,
41 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000042 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000043 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 std::forward<Params>(params)...);
45}
46
47} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
63struct Rule
64{
65 bool operator()() const
66 {
67 return m_Res;
68 }
69
70 bool m_Res = true;
71};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
87struct TypesAreEqual : public Rule
88{
89 template<typename ... Ts>
90 TypesAreEqual(const Ts&... ts)
91 {
92 m_Res = AllTypesAreEqualImpl(ts...);
93 }
94};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
Derek Lamberti50db4e82019-03-13 14:16:15 +0000133struct ShapesAreBroadcastCompatible : public Rule
134{
135 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
136 {
137 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
138 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
139 return sizeIn;
140 }
141
142 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
143 {
144 const TensorShape& shape0 = in0.GetShape();
145 const TensorShape& shape1 = in1.GetShape();
146 const TensorShape& outShape = out.GetShape();
147
148 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
149 {
150 unsigned int sizeOut = outShape[i];
151 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
152 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
153
154 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
155 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
156 }
157 }
158};
159} // namespace
160
161
arovir011c7c81b2018-10-08 11:34:28 +0100162bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const ActivationDescriptor& descriptor,
165 Optional<std::string&> reasonIfUnsupported) const
166{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000167 bool supported = true;
168
169 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100170 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000171 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100172 DataType::QuantisedAsymm8,
173 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000174 };
175
176 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
177 "Reference activation: input type not supported.");
178
179 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
180 "Reference activation: output type not supported.");
181
182 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
183 "Reference activation: input and output types mismatched.");
184
185 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
186 "Reference activation: input and output shapes are of different rank.");
187
188
189 struct ActivationFunctionSupported : public Rule
190 {
191 ActivationFunctionSupported(const ActivationDescriptor& desc)
192 {
193 switch(desc.m_Function)
194 {
195 case ActivationFunction::Abs:
196 case ActivationFunction::BoundedReLu:
197 case ActivationFunction::LeakyReLu:
198 case ActivationFunction::Linear:
199 case ActivationFunction::ReLu:
200 case ActivationFunction::Sigmoid:
201 case ActivationFunction::SoftReLu:
202 case ActivationFunction::Sqrt:
203 case ActivationFunction::Square:
204 case ActivationFunction::TanH:
205 {
206 m_Res = true;
207 break;
208 }
209 default:
210 {
211 m_Res = false;
212 break;
213 }
214 }
215 }
216 };
217
218 // Function is supported
219 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
220 "Reference activation: function not supported.");
221
222 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100223}
224
225bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
226 const TensorInfo& input1,
227 const TensorInfo& output,
228 Optional<std::string&> reasonIfUnsupported) const
229{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000230 bool supported = true;
231
Sadik Armagan2999a022019-04-09 14:20:12 +0100232 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000233 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100234 DataType::QuantisedAsymm8,
235 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000236 };
237
238 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
239 "Reference addition: input 0 is not a supported type.");
240
241 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
242 "Reference addition: input 1 is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
245 "Reference addition: output is not a supported type.");
246
247 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
248 "Reference addition: input 0 and Input 1 types are mismatched");
249
250 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
251 "Reference addition: input and output types are mismatched");
252
253 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
254 "Reference addition: shapes are not suitable for implicit broadcast.");
255
256 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100257}
258
259bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
260 const TensorInfo& output,
261 const TensorInfo& mean,
262 const TensorInfo& var,
263 const TensorInfo& beta,
264 const TensorInfo& gamma,
265 const BatchNormalizationDescriptor& descriptor,
266 Optional<std::string&> reasonIfUnsupported) const
267{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100268 ignore_unused(output);
269 ignore_unused(mean);
270 ignore_unused(var);
271 ignore_unused(beta);
272 ignore_unused(gamma);
273 ignore_unused(descriptor);
274 return IsSupportedForDataTypeRef(reasonIfUnsupported,
275 input.GetDataType(),
276 &TrueFunc<>,
277 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000280bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const BatchToSpaceNdDescriptor& descriptor,
283 Optional<std::string&> reasonIfUnsupported) const
284{
285 ignore_unused(descriptor);
286 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
287 input.GetDataType(),
288 &TrueFunc<>,
289 &TrueFunc<>) &&
290 IsSupportedForDataTypeRef(reasonIfUnsupported,
291 output.GetDataType(),
292 &TrueFunc<>,
293 &TrueFunc<>));
294}
295
arovir011c7c81b2018-10-08 11:34:28 +0100296bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
297 Optional<std::string&> reasonIfUnsupported) const
298{
narpra01db2b1602019-01-23 15:23:11 +0000299 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
300 output.GetDataType(),
301 &FalseFunc<>,
302 &TrueFunc<>,
303 &TrueFunc<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000304 &TrueFunc<>,
305 &FalseFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100306}
307
308bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
309 const TensorInfo& output,
310 Optional<std::string&> reasonIfUnsupported) const
311{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100312 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
313 input.GetDataType(),
314 &TrueFunc<>,
315 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000316 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000317 &FalseFuncI32<>,
318 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100319 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
320 output.GetDataType(),
321 &FalseOutputFuncF16<>,
322 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000323 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000324 &FalseFuncI32<>,
325 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100326}
327
328bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
329 const TensorInfo& output,
330 Optional<std::string&> reasonIfUnsupported) const
331{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100332 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
333 input.GetDataType(),
334 &FalseInputFuncF16<>,
335 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000336 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000337 &FalseFuncI32<>,
338 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100339 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
340 output.GetDataType(),
341 &TrueFunc<>,
342 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000343 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000344 &FalseFuncI32<>,
345 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100346}
347
348bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
349 const TensorInfo& output,
350 const Convolution2dDescriptor& descriptor,
351 const TensorInfo& weights,
352 const Optional<TensorInfo>& biases,
353 Optional<std::string&> reasonIfUnsupported) const
354{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100355 ignore_unused(output);
356 ignore_unused(descriptor);
357 ignore_unused(weights);
358 ignore_unused(biases);
359 return IsSupportedForDataTypeRef(reasonIfUnsupported,
360 input.GetDataType(),
361 &TrueFunc<>,
362 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100363}
364
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000365bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
366 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000367 Optional<std::string&> reasonIfUnsupported) const
368{
369 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000370 return IsSupportedForDataTypeRef(reasonIfUnsupported,
371 input.GetDataType(),
372 &TrueFunc<>,
373 &TrueFunc<>);
374}
375
arovir011c7c81b2018-10-08 11:34:28 +0100376bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
377 const TensorInfo& output,
378 const DepthwiseConvolution2dDescriptor& descriptor,
379 const TensorInfo& weights,
380 const Optional<TensorInfo>& biases,
381 Optional<std::string&> reasonIfUnsupported) const
382{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100383 ignore_unused(output);
384 ignore_unused(descriptor);
385 ignore_unused(weights);
386 ignore_unused(biases);
387 return IsSupportedForDataTypeRef(reasonIfUnsupported,
388 input.GetDataType(),
389 &TrueFunc<>,
390 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100391}
392
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000393bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
394 const TensorInfo& output,
395 Optional<std::string&> reasonIfUnsupported) const
396{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100397 bool supported = true;
398
399 std::array<DataType,2> supportedInputTypes = {
400 DataType::QuantisedAsymm8,
401 DataType::QuantisedSymm16
402 };
403
404 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
405 "Reference dequantize: input type not supported.");
406
407 std::array<DataType,2> supportedOutputTypes = {
408 DataType::Float32,
409 };
410
411 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
412 "Reference dequantize: output type not supported.");
413
414 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
415 "Reference dequantize: input and output shapes have different num total elements.");
416
417 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000418}
419
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000420bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
421 const armnn::TensorInfo& input1,
422 const armnn::DetectionPostProcessDescriptor& descriptor,
423 armnn::Optional<std::string&> reasonIfUnsupported) const
424{
425 ignore_unused(input1);
426 return IsSupportedForDataTypeRef(reasonIfUnsupported,
427 input0.GetDataType(),
428 &TrueFunc<>,
429 &TrueFunc<>);
430}
431
Pablo Tellof0bd6832019-04-26 17:58:13 +0100432bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
433 const TensorInfo& output,
434 const DepthwiseConvolution2dDescriptor& descriptor,
435 const TensorInfo& weights,
436 const Optional<TensorInfo>& biases,
437 Optional<std::string&> reasonIfUnsupported) const
438{
439 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
440 {
441 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
442 }
443 else
444 {
445 if (reasonIfUnsupported)
446 {
447 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
448 }
449 return false;
450 }
451}
452
453
454 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100455 const TensorInfo& input1,
456 const TensorInfo& output,
457 Optional<std::string&> reasonIfUnsupported) const
458{
Sadik Armagan2999a022019-04-09 14:20:12 +0100459 bool supported = true;
460
461 std::array<DataType,3> supportedTypes = {
462 DataType::Float32,
463 DataType::QuantisedAsymm8,
464 DataType::QuantisedSymm16
465 };
466
467 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
468 "Reference division: input 0 is not a supported type.");
469
470 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
471 "Reference division: input 1 is not a supported type.");
472
473 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
474 "Reference division: output is not a supported type.");
475
476 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
477 "Reference division: input 0 and Input 1 types are mismatched");
478
479 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
480 "Reference division: input and output types are mismatched");
481
482 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
483 "Reference division: shapes are not suitable for implicit broadcast.");
484
485 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100486}
487
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000488bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
489 const TensorInfo& input1,
490 const TensorInfo& output,
491 Optional<std::string&> reasonIfUnsupported) const
492{
493 ignore_unused(input0);
494 ignore_unused(input1);
495 ignore_unused(output);
496 ignore_unused(reasonIfUnsupported);
497 return IsSupportedForDataTypeRef(reasonIfUnsupported,
498 input0.GetDataType(),
499 &TrueFunc<>,
500 &TrueFunc<>);
501}
502
arovir011c7c81b2018-10-08 11:34:28 +0100503bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
504 const FakeQuantizationDescriptor& descriptor,
505 Optional<std::string&> reasonIfUnsupported) const
506{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100507 ignore_unused(descriptor);
508 return IsSupportedForDataTypeRef(reasonIfUnsupported,
509 input.GetDataType(),
510 &TrueFunc<>,
511 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100512}
513
514bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
515 const TensorInfo& output,
516 Optional<std::string&> reasonIfUnsupported) const
517{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100518 ignore_unused(output);
519 return IsSupportedForDataTypeRef(reasonIfUnsupported,
520 input.GetDataType(),
521 &TrueFunc<>,
522 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100523}
524
525bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
526 const TensorInfo& output,
527 const TensorInfo& weights,
528 const TensorInfo& biases,
529 const FullyConnectedDescriptor& descriptor,
530 Optional<std::string&> reasonIfUnsupported) const
531{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100532 ignore_unused(output);
533 ignore_unused(weights);
534 ignore_unused(biases);
535 ignore_unused(descriptor);
536 return IsSupportedForDataTypeRef(reasonIfUnsupported,
537 input.GetDataType(),
538 &TrueFunc<>,
539 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100540}
541
narpra014951d842019-01-18 16:53:53 +0000542bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
543 const armnn::TensorInfo& input1,
544 const armnn::TensorInfo& output,
545 armnn::Optional<std::string&> reasonIfUnsupported) const
546{
547 ignore_unused(input1);
548 ignore_unused(output);
549 return IsSupportedForDataTypeRef(reasonIfUnsupported,
550 input0.GetDataType(),
551 &TrueFunc<>,
552 &TrueFunc<>);
553}
554
FrancisMurtagh878f0232018-12-19 10:56:15 +0000555bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
556 const TensorInfo& input1,
557 const TensorInfo& output,
558 Optional<std::string&> reasonIfUnsupported) const
559{
560 ignore_unused(input0);
561 ignore_unused(input1);
562 ignore_unused(output);
563 ignore_unused(reasonIfUnsupported);
564 return IsSupportedForDataTypeRef(reasonIfUnsupported,
565 input0.GetDataType(),
566 &TrueFunc<>,
567 &TrueFunc<>);
568}
569
arovir011c7c81b2018-10-08 11:34:28 +0100570bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
571 Optional<std::string&> reasonIfUnsupported) const
572{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100573 return IsSupportedForDataTypeRef(reasonIfUnsupported,
574 input.GetDataType(),
575 &TrueFunc<>,
576 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100577}
578
579bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
580 const TensorInfo& output,
581 const L2NormalizationDescriptor& descriptor,
582 Optional<std::string&> reasonIfUnsupported) const
583{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100584 ignore_unused(output);
585 ignore_unused(descriptor);
586 return IsSupportedForDataTypeRef(reasonIfUnsupported,
587 input.GetDataType(),
588 &TrueFunc<>,
589 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100590}
591
592bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
593 const TensorInfo& outputStateIn,
594 const TensorInfo& cellStateIn,
595 const TensorInfo& scratchBuffer,
596 const TensorInfo& outputStateOut,
597 const TensorInfo& cellStateOut,
598 const TensorInfo& output,
599 const LstmDescriptor& descriptor,
600 const TensorInfo& inputToForgetWeights,
601 const TensorInfo& inputToCellWeights,
602 const TensorInfo& inputToOutputWeights,
603 const TensorInfo& recurrentToForgetWeights,
604 const TensorInfo& recurrentToCellWeights,
605 const TensorInfo& recurrentToOutputWeights,
606 const TensorInfo& forgetGateBias,
607 const TensorInfo& cellBias,
608 const TensorInfo& outputGateBias,
609 const TensorInfo* inputToInputWeights,
610 const TensorInfo* recurrentToInputWeights,
611 const TensorInfo* cellToInputWeights,
612 const TensorInfo* inputGateBias,
613 const TensorInfo* projectionWeights,
614 const TensorInfo* projectionBias,
615 const TensorInfo* cellToForgetWeights,
616 const TensorInfo* cellToOutputWeights,
617 Optional<std::string&> reasonIfUnsupported) const
618{
telsoa01c577f2c2018-08-31 09:22:23 +0100619 ignore_unused(descriptor);
620 ignore_unused(inputToForgetWeights);
621 ignore_unused(inputToCellWeights);
622 ignore_unused(inputToOutputWeights);
623 ignore_unused(recurrentToForgetWeights);
624 ignore_unused(recurrentToCellWeights);
625 ignore_unused(recurrentToOutputWeights);
626 ignore_unused(forgetGateBias);
627 ignore_unused(cellBias);
628 ignore_unused(outputGateBias);
629 ignore_unused(inputToInputWeights);
630 ignore_unused(recurrentToInputWeights);
631 ignore_unused(cellToInputWeights);
632 ignore_unused(inputGateBias);
633 ignore_unused(projectionWeights);
634 ignore_unused(projectionBias);
635 ignore_unused(cellToForgetWeights);
636 ignore_unused(cellToOutputWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100637
638 bool supported = true;
639
640 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100641 DataType::Float32,
642 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100643 };
644
645 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
646 "Reference Lstm: input is not a supported type.");
647
648 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
649 "Reference Lstm: input and outputStateIn types are mismatched");
650
651 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
652 "Reference Lstm: input and cellStateIn types are mismatched");
653
654 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
655 "Reference Lstm: input and scratchBuffer types are mismatched");
656
657 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
658 "Reference Lstm: input and outputStateOut types are mismatched");
659
660 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
661 "Reference Lstm: input and cellStateOut types are mismatched");
662
663 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
664 "Reference Lstm: input and output types are mismatched");
665
666 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100667}
668
saoste012df12b32018-11-28 16:57:20 +0000669bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
670 const TensorInfo& input1,
671 const TensorInfo& output,
672 Optional<std::string&> reasonIfUnsupported) const
673{
Sadik Armagan2999a022019-04-09 14:20:12 +0100674 bool supported = true;
675
676 std::array<DataType,3> supportedTypes = {
677 DataType::Float32,
678 DataType::QuantisedAsymm8,
679 DataType::QuantisedSymm16
680 };
681
682 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
683 "Reference maximum: input 0 is not a supported type.");
684
685 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
686 "Reference maximum: input 1 is not a supported type.");
687
688 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
689 "Reference maximum: output is not a supported type.");
690
691 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
692 "Reference maximum: input 0 and Input 1 types are mismatched");
693
694 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
695 "Reference maximum: input and output types are mismatched");
696
697 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
698 "Reference maximum: shapes are not suitable for implicit broadcast.");
699
700 return supported;
saoste012df12b32018-11-28 16:57:20 +0000701}
702
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100703bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
704 const TensorInfo& output,
705 const MeanDescriptor& descriptor,
706 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100707{
narpra011e4c31d2018-09-28 11:07:51 +0100708 ignore_unused(output);
709 ignore_unused(descriptor);
710 return IsSupportedForDataTypeRef(reasonIfUnsupported,
711 input.GetDataType(),
712 &TrueFunc<>,
713 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100714}
715
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100716bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000717 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100718 const OriginsDescriptor& descriptor,
719 Optional<std::string&> reasonIfUnsupported) const
720{
721 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000722 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100723 return IsSupportedForDataTypeRef(reasonIfUnsupported,
724 inputs[0]->GetDataType(),
725 &TrueFunc<>,
726 &TrueFunc<>);
727}
728
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000729bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
730 const TensorInfo &output,
731 Optional<std::string &> reasonIfUnsupported) const
732{
733 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000734 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
735 input.GetDataType(),
736 &TrueFunc<>,
737 &TrueFunc<>,
738 &TrueFunc<>,
739 &FalseFuncI32<>,
740 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000741}
742
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000743bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
744 const TensorInfo& input1,
745 const TensorInfo& output,
746 Optional<std::string&> reasonIfUnsupported) const
747{
Sadik Armagan2999a022019-04-09 14:20:12 +0100748 bool supported = true;
749
750 std::array<DataType,3> supportedTypes = {
751 DataType::Float32,
752 DataType::QuantisedAsymm8,
753 DataType::QuantisedSymm16
754 };
755
756 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
757 "Reference minimum: input 0 is not a supported type.");
758
759 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
760 "Reference minimum: input 1 is not a supported type.");
761
762 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
763 "Reference minimum: output is not a supported type.");
764
765 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
766 "Reference minimum: input 0 and Input 1 types are mismatched");
767
768 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
769 "Reference minimum: input and output types are mismatched");
770
771 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
772 "Reference minimum: shapes are not suitable for implicit broadcast.");
773
774 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000775}
776
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100777bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
778 const TensorInfo& input1,
779 const TensorInfo& output,
780 Optional<std::string&> reasonIfUnsupported) const
781{
Sadik Armagan2999a022019-04-09 14:20:12 +0100782 bool supported = true;
783
784 std::array<DataType,3> supportedTypes = {
785 DataType::Float32,
786 DataType::QuantisedAsymm8,
787 DataType::QuantisedSymm16
788 };
789
790 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
791 "Reference multiplication: input 0 is not a supported type.");
792
793 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
794 "Reference multiplication: input 1 is not a supported type.");
795
796 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
797 "Reference multiplication: output is not a supported type.");
798
799 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
800 "Reference multiplication: input 0 and Input 1 types are mismatched");
801
802 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
803 "Reference multiplication: input and output types are mismatched");
804
805 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
806 "Reference multiplication: shapes are not suitable for implicit broadcast.");
807
808 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100809}
810
811bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
812 const TensorInfo& output,
813 const NormalizationDescriptor& descriptor,
814 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100815{
816 ignore_unused(output);
817 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100818 return IsSupportedForDataTypeRef(reasonIfUnsupported,
819 input.GetDataType(),
820 &TrueFunc<>,
821 &FalseFuncU8<>);
822}
823
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // An Output layer moves data out of the network, so nearly every data
    // type is accepted. The function pointers are positional — one per
    // DataType in the order IsSupportedForDataTypeGeneric dispatches them;
    // the FalseFuncI32 slot shows that Signed32 is the one rejected type.
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,  // Signed32 outputs are not supported.
                                         &TrueFunc<>);
}
835
836bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
837 const TensorInfo& output,
838 const PadDescriptor& descriptor,
839 Optional<std::string&> reasonIfUnsupported) const
840{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100841 ignore_unused(output);
842 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000843 return IsSupportedForDataTypeRef(reasonIfUnsupported,
844 input.GetDataType(),
845 &TrueFunc<>,
846 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100847}
848
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100849bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
850 const TensorInfo& output,
851 const PermuteDescriptor& descriptor,
852 Optional<std::string&> reasonIfUnsupported) const
853{
854 ignore_unused(output);
855 ignore_unused(descriptor);
856 return IsSupportedForDataTypeRef(reasonIfUnsupported,
857 input.GetDataType(),
858 &TrueFunc<>,
859 &TrueFunc<>);
860}
861
862bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
863 const TensorInfo& output,
864 const Pooling2dDescriptor& descriptor,
865 Optional<std::string&> reasonIfUnsupported) const
866{
867 ignore_unused(output);
868 ignore_unused(descriptor);
869 return IsSupportedForDataTypeRef(reasonIfUnsupported,
870 input.GetDataType(),
871 &TrueFunc<>,
872 &TrueFunc<>);
873}
874
Derek Lamberti5f400d62019-03-25 15:41:58 +0000875bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
876 const TensorInfo& output,
877 Optional<std::string&> reasonIfUnsupported) const
878{
879 bool supported = true;
880
881 // Define supported output types.
882 std::array<DataType,2> supportedInputTypes = {
883 DataType::Float32,
884 };
885
886 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
887 "Reference quantize: input type not supported.");
888
889 // Define supported output types.
890 std::array<DataType,2> supportedOutputTypes = {
891 DataType::QuantisedAsymm8,
892 DataType::QuantisedSymm16
893 };
894 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
895 "Reference quantize: output type not supported.");
896
897 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
898 "Reference quantize: input and output shapes have different num total elements.");
899
900 return supported;
901}
902
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100903bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000904 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100905 Optional<std::string&> reasonIfUnsupported) const
906{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000907 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100908 return IsSupportedForDataTypeRef(reasonIfUnsupported,
909 input.GetDataType(),
910 &TrueFunc<>,
911 &TrueFunc<>);
912}
913
914bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000915 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100916 Optional<std::string&> reasonIfUnsupported) const
917{
Sadik Armaganc625f002018-12-17 11:32:16 +0000918 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100919 return IsSupportedForDataTypeRef(reasonIfUnsupported,
920 input.GetDataType(),
921 &TrueFunc<>,
922 &TrueFunc<>);
923}
924
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000925bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
926 const TensorInfo& output,
927 Optional<std::string&> reasonIfUnsupported) const
928{
929 ignore_unused(output);
930 return IsSupportedForDataTypeRef(reasonIfUnsupported,
931 input.GetDataType(),
932 &TrueFunc<>,
933 &FalseFuncU8<>);
934}
935
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100936bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
937 const TensorInfo& output,
938 const SoftmaxDescriptor& descriptor,
939 Optional<std::string&> reasonIfUnsupported) const
940{
941 ignore_unused(output);
942 ignore_unused(descriptor);
943 return IsSupportedForDataTypeRef(reasonIfUnsupported,
944 input.GetDataType(),
945 &TrueFunc<>,
946 &TrueFunc<>);
947}
948
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000949bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
950 const TensorInfo& output,
951 const SpaceToBatchNdDescriptor& descriptor,
952 Optional<std::string&> reasonIfUnsupported) const
953{
954 ignore_unused(output);
955 ignore_unused(descriptor);
956 return IsSupportedForDataTypeRef(reasonIfUnsupported,
957 input.GetDataType(),
958 &TrueFunc<>,
959 &TrueFunc<>);
960}
961
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100962bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
963 const ViewsDescriptor& descriptor,
964 Optional<std::string&> reasonIfUnsupported) const
965{
966 ignore_unused(descriptor);
967 return IsSupportedForDataTypeRef(reasonIfUnsupported,
968 input.GetDataType(),
969 &TrueFunc<>,
970 &TrueFunc<>);
971}
972
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000973bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
974 const TensorInfo& output,
975 const StridedSliceDescriptor& descriptor,
976 Optional<std::string&> reasonIfUnsupported) const
977{
978 ignore_unused(output);
979 ignore_unused(descriptor);
980 return IsSupportedForDataTypeRef(reasonIfUnsupported,
981 input.GetDataType(),
982 &TrueFunc<>,
983 &TrueFunc<>);
984}
985
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100986bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
987 const TensorInfo& input1,
988 const TensorInfo& output,
989 Optional<std::string&> reasonIfUnsupported) const
990{
Sadik Armagan2999a022019-04-09 14:20:12 +0100991 bool supported = true;
992
993 std::array<DataType,3> supportedTypes = {
994 DataType::Float32,
995 DataType::QuantisedAsymm8,
996 DataType::QuantisedSymm16
997 };
998
999 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1000 "Reference subtraction: input 0 is not a supported type.");
1001
1002 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1003 "Reference subtraction: input 1 is not a supported type.");
1004
1005 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1006 "Reference subtraction: output is not a supported type.");
1007
1008 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1009 "Reference subtraction: input 0 and Input 1 types are mismatched");
1010
1011 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1012 "Reference subtraction: input and output types are mismatched");
1013
1014 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1015 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1016
1017 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001018}
1019
arovir011c7c81b2018-10-08 11:34:28 +01001020} // namespace armnn