blob: 8eded84e09289bbb4d04add9e28fb9a6708b2315 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
namespace
{

// Helper used by the simpler Is*Supported queries below: dispatches on the
// tensor's DataType via IsSupportedForDataTypeGeneric. Only the Float32 and
// Uint8 predicates are caller-controlled; every other data type slot is
// unconditionally rejected with FalseFunc.
// NOTE(review): the three FalseFunc positions presumably correspond to
// Float16 / Signed32 / Boolean — confirm against IsSupportedForDataTypeGeneric
// in LayerSupportCommon.hpp.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
// Base type for all layer-support rules. A derived rule's constructor stores
// its verdict in m_Res; the rule then evaluates to that verdict when invoked.
// m_Res defaults to true so an "empty" rule passes.
struct Rule
{
    bool m_Res = true;

    bool operator()() const
    {
        return m_Res;
    }
};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
// Rule: passes when all given TensorInfos share one DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
Derek Lamberti50db4e82019-03-13 14:16:15 +0000133struct ShapesAreBroadcastCompatible : public Rule
134{
135 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
136 {
137 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
138 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
139 return sizeIn;
140 }
141
142 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
143 {
144 const TensorShape& shape0 = in0.GetShape();
145 const TensorShape& shape1 = in1.GetShape();
146 const TensorShape& outShape = out.GetShape();
147
148 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
149 {
150 unsigned int sizeOut = outShape[i];
151 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
152 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
153
154 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
155 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
156 }
157 }
158};
159} // namespace
160
161
// Checks whether the reference backend can run an Activation layer.
// All failed rules append their reason to reasonIfUnsupported; every rule is
// always evaluated so the caller gets a complete list of problems.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: passes only for the activation functions the reference
    // implementation provides; anything else falls into the default case.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
224
225bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
226 const TensorInfo& input1,
227 const TensorInfo& output,
228 Optional<std::string&> reasonIfUnsupported) const
229{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000230 bool supported = true;
231
Sadik Armagan2999a022019-04-09 14:20:12 +0100232 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000233 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100234 DataType::QuantisedAsymm8,
235 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000236 };
237
238 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
239 "Reference addition: input 0 is not a supported type.");
240
241 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
242 "Reference addition: input 1 is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
245 "Reference addition: output is not a supported type.");
246
247 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
248 "Reference addition: input 0 and Input 1 types are mismatched");
249
250 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
251 "Reference addition: input and output types are mismatched");
252
253 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
254 "Reference addition: shapes are not suitable for implicit broadcast.");
255
256 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100257}
258
// Checks whether the reference backend can run a BatchNormalization layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted); the
// statistics tensors, descriptor and output are not validated here.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& var,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(mean);
    ignore_unused(var);
    ignore_unused(beta);
    ignore_unused(gamma);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
279
// Checks whether the reference backend can run a BatchToSpaceNd layer.
// Both input and output must be Float32 or Uint8. Note the && short-circuits:
// if the input check fails, no reason is recorded for the output.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return (IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      input.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>) &&
            IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      output.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>));
}
295
// Checks whether the reference backend can produce a Constant layer with the
// given output type. Three of the per-type predicates accept and two reject;
// NOTE(review): the slot order presumably follows the DataType enum
// (Float16, Float32, Uint8, Signed32, Boolean) — confirm against
// IsSupportedForDataTypeGeneric in LayerSupportCommon.hpp.
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &FalseFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFunc<>);
}
307
// Checks whether the reference backend can run a ConvertFp16ToFp32 layer:
// the input must be Float16 and the output Float32 (the False*Func guards
// emit a specific reason for each wrong-type combination).
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
327
// Checks whether the reference backend can run a ConvertFp32ToFp16 layer:
// the mirror of IsConvertFp16ToFp32Supported — input must be Float32 and
// output Float16.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
347
// Checks whether the reference backend can run a Convolution2d layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted); the
// descriptor, weights, biases and output are not validated here.
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
364
// Checks whether the reference backend can run a Debug layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted).
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
375
// Checks whether the reference backend can run a DepthwiseConvolution2d layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted); the
// descriptor, weights, biases and output are not validated here.
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
392
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000393bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
394 const TensorInfo& output,
395 Optional<std::string&> reasonIfUnsupported) const
396{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100397 bool supported = true;
398
399 std::array<DataType,2> supportedInputTypes = {
400 DataType::QuantisedAsymm8,
401 DataType::QuantisedSymm16
402 };
403
404 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
405 "Reference dequantize: input type not supported.");
406
407 std::array<DataType,2> supportedOutputTypes = {
408 DataType::Float32,
409 };
410
411 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
412 "Reference dequantize: output type not supported.");
413
414 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
415 "Reference dequantize: input and output shapes have different num total elements.");
416
417 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000418}
419
// Checks whether the reference backend can run a DetectionPostProcess layer.
// Only input0's data type is inspected (Float32 or Uint8 accepted); input1
// and the descriptor are not validated here.
bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
431
arovir011c7c81b2018-10-08 11:34:28 +0100432bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
433 const TensorInfo& input1,
434 const TensorInfo& output,
435 Optional<std::string&> reasonIfUnsupported) const
436{
Sadik Armagan2999a022019-04-09 14:20:12 +0100437 bool supported = true;
438
439 std::array<DataType,3> supportedTypes = {
440 DataType::Float32,
441 DataType::QuantisedAsymm8,
442 DataType::QuantisedSymm16
443 };
444
445 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
446 "Reference division: input 0 is not a supported type.");
447
448 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
449 "Reference division: input 1 is not a supported type.");
450
451 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
452 "Reference division: output is not a supported type.");
453
454 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
455 "Reference division: input 0 and Input 1 types are mismatched");
456
457 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
458 "Reference division: input and output types are mismatched");
459
460 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
461 "Reference division: shapes are not suitable for implicit broadcast.");
462
463 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100464}
465
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000466bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
467 const TensorInfo& input1,
468 const TensorInfo& output,
469 Optional<std::string&> reasonIfUnsupported) const
470{
471 ignore_unused(input0);
472 ignore_unused(input1);
473 ignore_unused(output);
474 ignore_unused(reasonIfUnsupported);
475 return IsSupportedForDataTypeRef(reasonIfUnsupported,
476 input0.GetDataType(),
477 &TrueFunc<>,
478 &TrueFunc<>);
479}
480
// Checks whether the reference backend can run a FakeQuantization layer.
// Float32 input only; Uint8 is explicitly rejected via FalseFuncU8.
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
491
// Checks whether the reference backend can run a Floor layer.
// Float32 input only; Uint8 is explicitly rejected via FalseFuncU8.
bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
502
// Checks whether the reference backend can run a FullyConnected layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted); the
// weights, biases, descriptor and output are not validated here.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(weights);
    ignore_unused(biases);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
519
// Checks whether the reference backend can run a Gather layer.
// Only input0's (the params tensor's) data type is inspected (Float32 or
// Uint8 accepted); the indices tensor and output are not validated here.
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
532
FrancisMurtagh878f0232018-12-19 10:56:15 +0000533bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
534 const TensorInfo& input1,
535 const TensorInfo& output,
536 Optional<std::string&> reasonIfUnsupported) const
537{
538 ignore_unused(input0);
539 ignore_unused(input1);
540 ignore_unused(output);
541 ignore_unused(reasonIfUnsupported);
542 return IsSupportedForDataTypeRef(reasonIfUnsupported,
543 input0.GetDataType(),
544 &TrueFunc<>,
545 &TrueFunc<>);
546}
547
// Checks whether the reference backend accepts the given tensor as a graph
// Input (Float32 or Uint8 accepted).
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
556
// Checks whether the reference backend can run an L2Normalization layer.
// Float32 input only; Uint8 is explicitly rejected via FalseFuncU8.
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
569
// Checks whether the reference backend can run an LSTM layer.
// Only data-type consistency is validated: the input must be Float32 or
// QSymm16, and all state/output tensors must share the input's type. The
// weight/bias tensors and descriptor are not inspected here.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // Weight/bias tensors and the descriptor are deliberately not validated.
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
646
saoste012df12b32018-11-28 16:57:20 +0000647bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
648 const TensorInfo& input1,
649 const TensorInfo& output,
650 Optional<std::string&> reasonIfUnsupported) const
651{
Sadik Armagan2999a022019-04-09 14:20:12 +0100652 bool supported = true;
653
654 std::array<DataType,3> supportedTypes = {
655 DataType::Float32,
656 DataType::QuantisedAsymm8,
657 DataType::QuantisedSymm16
658 };
659
660 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
661 "Reference maximum: input 0 is not a supported type.");
662
663 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
664 "Reference maximum: input 1 is not a supported type.");
665
666 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
667 "Reference maximum: output is not a supported type.");
668
669 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
670 "Reference maximum: input 0 and Input 1 types are mismatched");
671
672 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
673 "Reference maximum: input and output types are mismatched");
674
675 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
676 "Reference maximum: shapes are not suitable for implicit broadcast.");
677
678 return supported;
saoste012df12b32018-11-28 16:57:20 +0000679}
680
// Checks whether the reference backend can run a Mean (reduction) layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted).
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
693
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100694bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000695 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100696 const OriginsDescriptor& descriptor,
697 Optional<std::string&> reasonIfUnsupported) const
698{
699 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000700 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100701 return IsSupportedForDataTypeRef(reasonIfUnsupported,
702 inputs[0]->GetDataType(),
703 &TrueFunc<>,
704 &TrueFunc<>);
705}
706
// Checks whether the reference backend can run a MemCopy layer.
// All data types are accepted except the slot rejected via FalseFuncI32
// (Signed32, going by the predicate's name).
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
720
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000721bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
722 const TensorInfo& input1,
723 const TensorInfo& output,
724 Optional<std::string&> reasonIfUnsupported) const
725{
Sadik Armagan2999a022019-04-09 14:20:12 +0100726 bool supported = true;
727
728 std::array<DataType,3> supportedTypes = {
729 DataType::Float32,
730 DataType::QuantisedAsymm8,
731 DataType::QuantisedSymm16
732 };
733
734 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
735 "Reference minimum: input 0 is not a supported type.");
736
737 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
738 "Reference minimum: input 1 is not a supported type.");
739
740 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
741 "Reference minimum: output is not a supported type.");
742
743 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
744 "Reference minimum: input 0 and Input 1 types are mismatched");
745
746 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
747 "Reference minimum: input and output types are mismatched");
748
749 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
750 "Reference minimum: shapes are not suitable for implicit broadcast.");
751
752 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000753}
754
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100755bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
756 const TensorInfo& input1,
757 const TensorInfo& output,
758 Optional<std::string&> reasonIfUnsupported) const
759{
Sadik Armagan2999a022019-04-09 14:20:12 +0100760 bool supported = true;
761
762 std::array<DataType,3> supportedTypes = {
763 DataType::Float32,
764 DataType::QuantisedAsymm8,
765 DataType::QuantisedSymm16
766 };
767
768 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
769 "Reference multiplication: input 0 is not a supported type.");
770
771 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
772 "Reference multiplication: input 1 is not a supported type.");
773
774 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
775 "Reference multiplication: output is not a supported type.");
776
777 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
778 "Reference multiplication: input 0 and Input 1 types are mismatched");
779
780 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
781 "Reference multiplication: input and output types are mismatched");
782
783 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
784 "Reference multiplication: shapes are not suitable for implicit broadcast.");
785
786 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100787}
788
// Checks whether the reference backend can run a Normalization layer.
// Float32 input only; Uint8 is explicitly rejected via FalseFuncU8.
bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const NormalizationDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
801
// Checks whether the reference backend accepts the given tensor as a graph
// Output. All data types are accepted except the slot rejected via
// FalseFuncI32 (Signed32, going by the predicate's name).
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
813
// Checks whether the reference backend can run a Pad layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted).
bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
826
// Checks whether the reference backend can run a Permute layer.
// Only the input's data type is inspected (Float32 or Uint8 accepted).
bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
839
840bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
841 const TensorInfo& output,
842 const Pooling2dDescriptor& descriptor,
843 Optional<std::string&> reasonIfUnsupported) const
844{
845 ignore_unused(output);
846 ignore_unused(descriptor);
847 return IsSupportedForDataTypeRef(reasonIfUnsupported,
848 input.GetDataType(),
849 &TrueFunc<>,
850 &TrueFunc<>);
851}
852
Derek Lamberti5f400d62019-03-25 15:41:58 +0000853bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
854 const TensorInfo& output,
855 Optional<std::string&> reasonIfUnsupported) const
856{
857 bool supported = true;
858
859 // Define supported output types.
860 std::array<DataType,2> supportedInputTypes = {
861 DataType::Float32,
862 };
863
864 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
865 "Reference quantize: input type not supported.");
866
867 // Define supported output types.
868 std::array<DataType,2> supportedOutputTypes = {
869 DataType::QuantisedAsymm8,
870 DataType::QuantisedSymm16
871 };
872 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
873 "Reference quantize: output type not supported.");
874
875 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
876 "Reference quantize: input and output shapes have different num total elements.");
877
878 return supported;
879}
880
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100881bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000882 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100883 Optional<std::string&> reasonIfUnsupported) const
884{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000885 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100886 return IsSupportedForDataTypeRef(reasonIfUnsupported,
887 input.GetDataType(),
888 &TrueFunc<>,
889 &TrueFunc<>);
890}
891
892bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000893 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100894 Optional<std::string&> reasonIfUnsupported) const
895{
Sadik Armaganc625f002018-12-17 11:32:16 +0000896 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100897 return IsSupportedForDataTypeRef(reasonIfUnsupported,
898 input.GetDataType(),
899 &TrueFunc<>,
900 &TrueFunc<>);
901}
902
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000903bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
904 const TensorInfo& output,
905 Optional<std::string&> reasonIfUnsupported) const
906{
907 ignore_unused(output);
908 return IsSupportedForDataTypeRef(reasonIfUnsupported,
909 input.GetDataType(),
910 &TrueFunc<>,
911 &FalseFuncU8<>);
912}
913
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100914bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
915 const TensorInfo& output,
916 const SoftmaxDescriptor& descriptor,
917 Optional<std::string&> reasonIfUnsupported) const
918{
919 ignore_unused(output);
920 ignore_unused(descriptor);
921 return IsSupportedForDataTypeRef(reasonIfUnsupported,
922 input.GetDataType(),
923 &TrueFunc<>,
924 &TrueFunc<>);
925}
926
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000927bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
928 const TensorInfo& output,
929 const SpaceToBatchNdDescriptor& descriptor,
930 Optional<std::string&> reasonIfUnsupported) const
931{
932 ignore_unused(output);
933 ignore_unused(descriptor);
934 return IsSupportedForDataTypeRef(reasonIfUnsupported,
935 input.GetDataType(),
936 &TrueFunc<>,
937 &TrueFunc<>);
938}
939
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100940bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
941 const ViewsDescriptor& descriptor,
942 Optional<std::string&> reasonIfUnsupported) const
943{
944 ignore_unused(descriptor);
945 return IsSupportedForDataTypeRef(reasonIfUnsupported,
946 input.GetDataType(),
947 &TrueFunc<>,
948 &TrueFunc<>);
949}
950
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000951bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
952 const TensorInfo& output,
953 const StridedSliceDescriptor& descriptor,
954 Optional<std::string&> reasonIfUnsupported) const
955{
956 ignore_unused(output);
957 ignore_unused(descriptor);
958 return IsSupportedForDataTypeRef(reasonIfUnsupported,
959 input.GetDataType(),
960 &TrueFunc<>,
961 &TrueFunc<>);
962}
963
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100964bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
965 const TensorInfo& input1,
966 const TensorInfo& output,
967 Optional<std::string&> reasonIfUnsupported) const
968{
Sadik Armagan2999a022019-04-09 14:20:12 +0100969 bool supported = true;
970
971 std::array<DataType,3> supportedTypes = {
972 DataType::Float32,
973 DataType::QuantisedAsymm8,
974 DataType::QuantisedSymm16
975 };
976
977 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
978 "Reference subtraction: input 0 is not a supported type.");
979
980 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
981 "Reference subtraction: input 1 is not a supported type.");
982
983 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
984 "Reference subtraction: output is not a supported type.");
985
986 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
987 "Reference subtraction: input 0 and Input 1 types are mismatched");
988
989 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
990 "Reference subtraction: input and output types are mismatched");
991
992 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
993 "Reference subtraction: shapes are not suitable for implicit broadcast.");
994
995 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100996}
997
arovir011c7c81b2018-10-08 11:34:28 +0100998} // namespace armnn