blob: 67c13c3f84ca6b5bf01820fcf130055ff8a29ebe [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
// Legacy helper: dispatches to floatFuncPtr for Float32 and uint8FuncPtr for
// QuantisedAsymm8. All remaining data-type slots are hard-wired to FalseFunc,
// so every other data type is reported as unsupported.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}
46
47} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
// Base type for all support-rule functors: derived constructors compute the
// verdict into m_Res, and the call operator simply reports it.
struct Rule
{
    bool operator()() const { return m_Res; }

    bool m_Res = true; // Defaults to "supported" until a check fails.
};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
// Rule: all given TensorInfos share the same DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};
95
// Rule: both tensors use identical quantization scale and offset.
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};
104
// Rule: the tensor's DataType appears in the given container of allowed types.
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
        {
            return dt == info.GetDataType();
        });
    }
};
116
// Rule: both tensor shapes have the same number of dimensions.
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
Derek Lamberti50db4e82019-03-13 14:16:15 +0000133struct ShapesAreBroadcastCompatible : public Rule
134{
135 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
136 {
137 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
138 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
139 return sizeIn;
140 }
141
142 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
143 {
144 const TensorShape& shape0 = in0.GetShape();
145 const TensorShape& shape1 = in1.GetShape();
146 const TensorShape& outShape = out.GetShape();
147
148 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
149 {
150 unsigned int sizeOut = outShape[i];
151 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
152 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
153
154 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
155 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
156 }
157 }
158};
159} // namespace
160
161
arovir011c7c81b2018-10-08 11:34:28 +0100162bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const ActivationDescriptor& descriptor,
165 Optional<std::string&> reasonIfUnsupported) const
166{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000167 bool supported = true;
168
169 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100170 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000171 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100172 DataType::QuantisedAsymm8,
173 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000174 };
175
176 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
177 "Reference activation: input type not supported.");
178
179 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
180 "Reference activation: output type not supported.");
181
182 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
183 "Reference activation: input and output types mismatched.");
184
185 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
186 "Reference activation: input and output shapes are of different rank.");
187
188
189 struct ActivationFunctionSupported : public Rule
190 {
191 ActivationFunctionSupported(const ActivationDescriptor& desc)
192 {
193 switch(desc.m_Function)
194 {
195 case ActivationFunction::Abs:
196 case ActivationFunction::BoundedReLu:
197 case ActivationFunction::LeakyReLu:
198 case ActivationFunction::Linear:
199 case ActivationFunction::ReLu:
200 case ActivationFunction::Sigmoid:
201 case ActivationFunction::SoftReLu:
202 case ActivationFunction::Sqrt:
203 case ActivationFunction::Square:
204 case ActivationFunction::TanH:
205 {
206 m_Res = true;
207 break;
208 }
209 default:
210 {
211 m_Res = false;
212 break;
213 }
214 }
215 }
216 };
217
218 // Function is supported
219 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
220 "Reference activation: function not supported.");
221
222 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100223}
224
// Checks reference-backend support for element-wise addition: both inputs and
// the output must be a supported type, all three types must match, and the
// input shapes must be implicitly broadcastable to the output shape.
// All rules run unconditionally so every failure reason is reported.
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}
258
259bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
260 const TensorInfo& output,
261 const TensorInfo& mean,
262 const TensorInfo& var,
263 const TensorInfo& beta,
264 const TensorInfo& gamma,
265 const BatchNormalizationDescriptor& descriptor,
266 Optional<std::string&> reasonIfUnsupported) const
267{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100268 ignore_unused(output);
269 ignore_unused(mean);
270 ignore_unused(var);
271 ignore_unused(beta);
272 ignore_unused(gamma);
273 ignore_unused(descriptor);
274 return IsSupportedForDataTypeRef(reasonIfUnsupported,
275 input.GetDataType(),
276 &TrueFunc<>,
277 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000280bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const BatchToSpaceNdDescriptor& descriptor,
283 Optional<std::string&> reasonIfUnsupported) const
284{
285 ignore_unused(descriptor);
286 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
287 input.GetDataType(),
288 &TrueFunc<>,
289 &TrueFunc<>) &&
290 IsSupportedForDataTypeRef(reasonIfUnsupported,
291 output.GetDataType(),
292 &TrueFunc<>,
293 &TrueFunc<>));
294}
295
arovir011c7c81b2018-10-08 11:34:28 +0100296bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
297 Optional<std::string&> reasonIfUnsupported) const
298{
narpra01db2b1602019-01-23 15:23:11 +0000299 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
300 output.GetDataType(),
301 &FalseFunc<>,
302 &TrueFunc<>,
303 &TrueFunc<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000304 &TrueFunc<>,
305 &FalseFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100306}
307
// Fp16->Fp32 conversion: the input must be Float16 (FalseInputFuncF32 rejects
// a Float32 input with a dedicated message) and the output must be Float32
// (FalseOutputFuncF16 rejects Float16); all other types are rejected.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
327
// Fp32->Fp16 conversion: mirror image of IsConvertFp16ToFp32Supported —
// input must be Float32, output must be Float16; everything else is rejected.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
347
348bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
349 const TensorInfo& output,
350 const Convolution2dDescriptor& descriptor,
351 const TensorInfo& weights,
352 const Optional<TensorInfo>& biases,
353 Optional<std::string&> reasonIfUnsupported) const
354{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100355 ignore_unused(output);
356 ignore_unused(descriptor);
357 ignore_unused(weights);
358 ignore_unused(biases);
359 return IsSupportedForDataTypeRef(reasonIfUnsupported,
360 input.GetDataType(),
361 &TrueFunc<>,
362 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100363}
364
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000365bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
366 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000367 Optional<std::string&> reasonIfUnsupported) const
368{
369 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000370 return IsSupportedForDataTypeRef(reasonIfUnsupported,
371 input.GetDataType(),
372 &TrueFunc<>,
373 &TrueFunc<>);
374}
375
arovir011c7c81b2018-10-08 11:34:28 +0100376bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
377 const TensorInfo& output,
378 const DepthwiseConvolution2dDescriptor& descriptor,
379 const TensorInfo& weights,
380 const Optional<TensorInfo>& biases,
381 Optional<std::string&> reasonIfUnsupported) const
382{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100383 ignore_unused(output);
384 ignore_unused(descriptor);
385 ignore_unused(weights);
386 ignore_unused(biases);
387 return IsSupportedForDataTypeRef(reasonIfUnsupported,
388 input.GetDataType(),
389 &TrueFunc<>,
390 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100391}
392
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000393bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
394 const TensorInfo& output,
395 Optional<std::string&> reasonIfUnsupported) const
396{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100397 bool supported = true;
398
399 std::array<DataType,2> supportedInputTypes = {
400 DataType::QuantisedAsymm8,
401 DataType::QuantisedSymm16
402 };
403
404 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
405 "Reference dequantize: input type not supported.");
406
407 std::array<DataType,2> supportedOutputTypes = {
408 DataType::Float32,
409 };
410
411 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
412 "Reference dequantize: output type not supported.");
413
414 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
415 "Reference dequantize: input and output shapes have different num total elements.");
416
417 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000418}
419
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000420bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
421 const armnn::TensorInfo& input1,
422 const armnn::DetectionPostProcessDescriptor& descriptor,
423 armnn::Optional<std::string&> reasonIfUnsupported) const
424{
425 ignore_unused(input1);
426 return IsSupportedForDataTypeRef(reasonIfUnsupported,
427 input0.GetDataType(),
428 &TrueFunc<>,
429 &TrueFunc<>);
430}
431
arovir011c7c81b2018-10-08 11:34:28 +0100432bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
433 const TensorInfo& input1,
434 const TensorInfo& output,
435 Optional<std::string&> reasonIfUnsupported) const
436{
Sadik Armagan2999a022019-04-09 14:20:12 +0100437 bool supported = true;
438
439 std::array<DataType,3> supportedTypes = {
440 DataType::Float32,
441 DataType::QuantisedAsymm8,
442 DataType::QuantisedSymm16
443 };
444
445 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
446 "Reference division: input 0 is not a supported type.");
447
448 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
449 "Reference division: input 1 is not a supported type.");
450
451 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
452 "Reference division: output is not a supported type.");
453
454 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
455 "Reference division: input 0 and Input 1 types are mismatched");
456
457 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
458 "Reference division: input and output types are mismatched");
459
460 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
461 "Reference division: shapes are not suitable for implicit broadcast.");
462
463 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100464}
465
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000466bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
467 const TensorInfo& input1,
468 const TensorInfo& output,
469 Optional<std::string&> reasonIfUnsupported) const
470{
471 ignore_unused(input0);
472 ignore_unused(input1);
473 ignore_unused(output);
474 ignore_unused(reasonIfUnsupported);
475 return IsSupportedForDataTypeRef(reasonIfUnsupported,
476 input0.GetDataType(),
477 &TrueFunc<>,
478 &TrueFunc<>);
479}
480
arovir011c7c81b2018-10-08 11:34:28 +0100481bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
482 const FakeQuantizationDescriptor& descriptor,
483 Optional<std::string&> reasonIfUnsupported) const
484{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100485 ignore_unused(descriptor);
486 return IsSupportedForDataTypeRef(reasonIfUnsupported,
487 input.GetDataType(),
488 &TrueFunc<>,
489 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100490}
491
492bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
493 const TensorInfo& output,
494 Optional<std::string&> reasonIfUnsupported) const
495{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100496 ignore_unused(output);
497 return IsSupportedForDataTypeRef(reasonIfUnsupported,
498 input.GetDataType(),
499 &TrueFunc<>,
500 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100501}
502
503bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
504 const TensorInfo& output,
505 const TensorInfo& weights,
506 const TensorInfo& biases,
507 const FullyConnectedDescriptor& descriptor,
508 Optional<std::string&> reasonIfUnsupported) const
509{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100510 ignore_unused(output);
511 ignore_unused(weights);
512 ignore_unused(biases);
513 ignore_unused(descriptor);
514 return IsSupportedForDataTypeRef(reasonIfUnsupported,
515 input.GetDataType(),
516 &TrueFunc<>,
517 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100518}
519
narpra014951d842019-01-18 16:53:53 +0000520bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
521 const armnn::TensorInfo& input1,
522 const armnn::TensorInfo& output,
523 armnn::Optional<std::string&> reasonIfUnsupported) const
524{
525 ignore_unused(input1);
526 ignore_unused(output);
527 return IsSupportedForDataTypeRef(reasonIfUnsupported,
528 input0.GetDataType(),
529 &TrueFunc<>,
530 &TrueFunc<>);
531}
532
FrancisMurtagh878f0232018-12-19 10:56:15 +0000533bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
534 const TensorInfo& input1,
535 const TensorInfo& output,
536 Optional<std::string&> reasonIfUnsupported) const
537{
538 ignore_unused(input0);
539 ignore_unused(input1);
540 ignore_unused(output);
541 ignore_unused(reasonIfUnsupported);
542 return IsSupportedForDataTypeRef(reasonIfUnsupported,
543 input0.GetDataType(),
544 &TrueFunc<>,
545 &TrueFunc<>);
546}
547
arovir011c7c81b2018-10-08 11:34:28 +0100548bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
549 Optional<std::string&> reasonIfUnsupported) const
550{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100551 return IsSupportedForDataTypeRef(reasonIfUnsupported,
552 input.GetDataType(),
553 &TrueFunc<>,
554 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100555}
556
557bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
558 const TensorInfo& output,
559 const L2NormalizationDescriptor& descriptor,
560 Optional<std::string&> reasonIfUnsupported) const
561{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100562 ignore_unused(output);
563 ignore_unused(descriptor);
564 return IsSupportedForDataTypeRef(reasonIfUnsupported,
565 input.GetDataType(),
566 &TrueFunc<>,
567 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100568}
569
570bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
571 const TensorInfo& outputStateIn,
572 const TensorInfo& cellStateIn,
573 const TensorInfo& scratchBuffer,
574 const TensorInfo& outputStateOut,
575 const TensorInfo& cellStateOut,
576 const TensorInfo& output,
577 const LstmDescriptor& descriptor,
578 const TensorInfo& inputToForgetWeights,
579 const TensorInfo& inputToCellWeights,
580 const TensorInfo& inputToOutputWeights,
581 const TensorInfo& recurrentToForgetWeights,
582 const TensorInfo& recurrentToCellWeights,
583 const TensorInfo& recurrentToOutputWeights,
584 const TensorInfo& forgetGateBias,
585 const TensorInfo& cellBias,
586 const TensorInfo& outputGateBias,
587 const TensorInfo* inputToInputWeights,
588 const TensorInfo* recurrentToInputWeights,
589 const TensorInfo* cellToInputWeights,
590 const TensorInfo* inputGateBias,
591 const TensorInfo* projectionWeights,
592 const TensorInfo* projectionBias,
593 const TensorInfo* cellToForgetWeights,
594 const TensorInfo* cellToOutputWeights,
595 Optional<std::string&> reasonIfUnsupported) const
596{
telsoa01c577f2c2018-08-31 09:22:23 +0100597 ignore_unused(descriptor);
598 ignore_unused(inputToForgetWeights);
599 ignore_unused(inputToCellWeights);
600 ignore_unused(inputToOutputWeights);
601 ignore_unused(recurrentToForgetWeights);
602 ignore_unused(recurrentToCellWeights);
603 ignore_unused(recurrentToOutputWeights);
604 ignore_unused(forgetGateBias);
605 ignore_unused(cellBias);
606 ignore_unused(outputGateBias);
607 ignore_unused(inputToInputWeights);
608 ignore_unused(recurrentToInputWeights);
609 ignore_unused(cellToInputWeights);
610 ignore_unused(inputGateBias);
611 ignore_unused(projectionWeights);
612 ignore_unused(projectionBias);
613 ignore_unused(cellToForgetWeights);
614 ignore_unused(cellToOutputWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100615
616 bool supported = true;
617
618 std::array<DataType,2> supportedTypes = {
619 DataType::Float32
620 };
621
622 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
623 "Reference Lstm: input is not a supported type.");
624
625 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
626 "Reference Lstm: input and outputStateIn types are mismatched");
627
628 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
629 "Reference Lstm: input and cellStateIn types are mismatched");
630
631 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
632 "Reference Lstm: input and scratchBuffer types are mismatched");
633
634 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
635 "Reference Lstm: input and outputStateOut types are mismatched");
636
637 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
638 "Reference Lstm: input and cellStateOut types are mismatched");
639
640 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
641 "Reference Lstm: input and output types are mismatched");
642
643 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100644}
645
saoste012df12b32018-11-28 16:57:20 +0000646bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
647 const TensorInfo& input1,
648 const TensorInfo& output,
649 Optional<std::string&> reasonIfUnsupported) const
650{
Sadik Armagan2999a022019-04-09 14:20:12 +0100651 bool supported = true;
652
653 std::array<DataType,3> supportedTypes = {
654 DataType::Float32,
655 DataType::QuantisedAsymm8,
656 DataType::QuantisedSymm16
657 };
658
659 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
660 "Reference maximum: input 0 is not a supported type.");
661
662 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
663 "Reference maximum: input 1 is not a supported type.");
664
665 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
666 "Reference maximum: output is not a supported type.");
667
668 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
669 "Reference maximum: input 0 and Input 1 types are mismatched");
670
671 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
672 "Reference maximum: input and output types are mismatched");
673
674 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
675 "Reference maximum: shapes are not suitable for implicit broadcast.");
676
677 return supported;
saoste012df12b32018-11-28 16:57:20 +0000678}
679
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100680bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
681 const TensorInfo& output,
682 const MeanDescriptor& descriptor,
683 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100684{
narpra011e4c31d2018-09-28 11:07:51 +0100685 ignore_unused(output);
686 ignore_unused(descriptor);
687 return IsSupportedForDataTypeRef(reasonIfUnsupported,
688 input.GetDataType(),
689 &TrueFunc<>,
690 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100691}
692
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100693bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000694 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100695 const OriginsDescriptor& descriptor,
696 Optional<std::string&> reasonIfUnsupported) const
697{
698 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000699 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100700 return IsSupportedForDataTypeRef(reasonIfUnsupported,
701 inputs[0]->GetDataType(),
702 &TrueFunc<>,
703 &TrueFunc<>);
704}
705
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000706bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
707 const TensorInfo &output,
708 Optional<std::string &> reasonIfUnsupported) const
709{
710 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000711 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
712 input.GetDataType(),
713 &TrueFunc<>,
714 &TrueFunc<>,
715 &TrueFunc<>,
716 &FalseFuncI32<>,
717 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000718}
719
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000720bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
721 const TensorInfo& input1,
722 const TensorInfo& output,
723 Optional<std::string&> reasonIfUnsupported) const
724{
Sadik Armagan2999a022019-04-09 14:20:12 +0100725 bool supported = true;
726
727 std::array<DataType,3> supportedTypes = {
728 DataType::Float32,
729 DataType::QuantisedAsymm8,
730 DataType::QuantisedSymm16
731 };
732
733 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
734 "Reference minimum: input 0 is not a supported type.");
735
736 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
737 "Reference minimum: input 1 is not a supported type.");
738
739 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
740 "Reference minimum: output is not a supported type.");
741
742 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
743 "Reference minimum: input 0 and Input 1 types are mismatched");
744
745 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
746 "Reference minimum: input and output types are mismatched");
747
748 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
749 "Reference minimum: shapes are not suitable for implicit broadcast.");
750
751 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000752}
753
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100754bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
755 const TensorInfo& input1,
756 const TensorInfo& output,
757 Optional<std::string&> reasonIfUnsupported) const
758{
Sadik Armagan2999a022019-04-09 14:20:12 +0100759 bool supported = true;
760
761 std::array<DataType,3> supportedTypes = {
762 DataType::Float32,
763 DataType::QuantisedAsymm8,
764 DataType::QuantisedSymm16
765 };
766
767 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
768 "Reference multiplication: input 0 is not a supported type.");
769
770 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
771 "Reference multiplication: input 1 is not a supported type.");
772
773 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
774 "Reference multiplication: output is not a supported type.");
775
776 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
777 "Reference multiplication: input 0 and Input 1 types are mismatched");
778
779 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
780 "Reference multiplication: input and output types are mismatched");
781
782 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
783 "Reference multiplication: shapes are not suitable for implicit broadcast.");
784
785 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100786}
787
788bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
789 const TensorInfo& output,
790 const NormalizationDescriptor& descriptor,
791 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100792{
793 ignore_unused(output);
794 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100795 return IsSupportedForDataTypeRef(reasonIfUnsupported,
796 input.GetDataType(),
797 &TrueFunc<>,
798 &FalseFuncU8<>);
799}
800
// Network outputs are supported for every data type except Signed32, which is
// explicitly rejected (FalseFuncI32 in the fourth dispatch slot).
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
812
813bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
814 const TensorInfo& output,
815 const PadDescriptor& descriptor,
816 Optional<std::string&> reasonIfUnsupported) const
817{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100818 ignore_unused(output);
819 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000820 return IsSupportedForDataTypeRef(reasonIfUnsupported,
821 input.GetDataType(),
822 &TrueFunc<>,
823 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100824}
825
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100826bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
827 const TensorInfo& output,
828 const PermuteDescriptor& descriptor,
829 Optional<std::string&> reasonIfUnsupported) const
830{
831 ignore_unused(output);
832 ignore_unused(descriptor);
833 return IsSupportedForDataTypeRef(reasonIfUnsupported,
834 input.GetDataType(),
835 &TrueFunc<>,
836 &TrueFunc<>);
837}
838
839bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
840 const TensorInfo& output,
841 const Pooling2dDescriptor& descriptor,
842 Optional<std::string&> reasonIfUnsupported) const
843{
844 ignore_unused(output);
845 ignore_unused(descriptor);
846 return IsSupportedForDataTypeRef(reasonIfUnsupported,
847 input.GetDataType(),
848 &TrueFunc<>,
849 &TrueFunc<>);
850}
851
Derek Lamberti5f400d62019-03-25 15:41:58 +0000852bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
853 const TensorInfo& output,
854 Optional<std::string&> reasonIfUnsupported) const
855{
856 bool supported = true;
857
858 // Define supported output types.
859 std::array<DataType,2> supportedInputTypes = {
860 DataType::Float32,
861 };
862
863 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
864 "Reference quantize: input type not supported.");
865
866 // Define supported output types.
867 std::array<DataType,2> supportedOutputTypes = {
868 DataType::QuantisedAsymm8,
869 DataType::QuantisedSymm16
870 };
871 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
872 "Reference quantize: output type not supported.");
873
874 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
875 "Reference quantize: input and output shapes have different num total elements.");
876
877 return supported;
878}
879
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100880bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000881 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100882 Optional<std::string&> reasonIfUnsupported) const
883{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000884 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100885 return IsSupportedForDataTypeRef(reasonIfUnsupported,
886 input.GetDataType(),
887 &TrueFunc<>,
888 &TrueFunc<>);
889}
890
891bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000892 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100893 Optional<std::string&> reasonIfUnsupported) const
894{
Sadik Armaganc625f002018-12-17 11:32:16 +0000895 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100896 return IsSupportedForDataTypeRef(reasonIfUnsupported,
897 input.GetDataType(),
898 &TrueFunc<>,
899 &TrueFunc<>);
900}
901
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000902bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
903 const TensorInfo& output,
904 Optional<std::string&> reasonIfUnsupported) const
905{
906 ignore_unused(output);
907 return IsSupportedForDataTypeRef(reasonIfUnsupported,
908 input.GetDataType(),
909 &TrueFunc<>,
910 &FalseFuncU8<>);
911}
912
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100913bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
914 const TensorInfo& output,
915 const SoftmaxDescriptor& descriptor,
916 Optional<std::string&> reasonIfUnsupported) const
917{
918 ignore_unused(output);
919 ignore_unused(descriptor);
920 return IsSupportedForDataTypeRef(reasonIfUnsupported,
921 input.GetDataType(),
922 &TrueFunc<>,
923 &TrueFunc<>);
924}
925
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000926bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
927 const TensorInfo& output,
928 const SpaceToBatchNdDescriptor& descriptor,
929 Optional<std::string&> reasonIfUnsupported) const
930{
931 ignore_unused(output);
932 ignore_unused(descriptor);
933 return IsSupportedForDataTypeRef(reasonIfUnsupported,
934 input.GetDataType(),
935 &TrueFunc<>,
936 &TrueFunc<>);
937}
938
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100939bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
940 const ViewsDescriptor& descriptor,
941 Optional<std::string&> reasonIfUnsupported) const
942{
943 ignore_unused(descriptor);
944 return IsSupportedForDataTypeRef(reasonIfUnsupported,
945 input.GetDataType(),
946 &TrueFunc<>,
947 &TrueFunc<>);
948}
949
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000950bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
951 const TensorInfo& output,
952 const StridedSliceDescriptor& descriptor,
953 Optional<std::string&> reasonIfUnsupported) const
954{
955 ignore_unused(output);
956 ignore_unused(descriptor);
957 return IsSupportedForDataTypeRef(reasonIfUnsupported,
958 input.GetDataType(),
959 &TrueFunc<>,
960 &TrueFunc<>);
961}
962
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100963bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
964 const TensorInfo& input1,
965 const TensorInfo& output,
966 Optional<std::string&> reasonIfUnsupported) const
967{
Sadik Armagan2999a022019-04-09 14:20:12 +0100968 bool supported = true;
969
970 std::array<DataType,3> supportedTypes = {
971 DataType::Float32,
972 DataType::QuantisedAsymm8,
973 DataType::QuantisedSymm16
974 };
975
976 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
977 "Reference subtraction: input 0 is not a supported type.");
978
979 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
980 "Reference subtraction: input 1 is not a supported type.");
981
982 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
983 "Reference subtraction: output is not a supported type.");
984
985 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
986 "Reference subtraction: input 0 and Input 1 types are mismatched");
987
988 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
989 "Reference subtraction: input and output types are mismatched");
990
991 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
992 "Reference subtraction: shapes are not suitable for implicit broadcast.");
993
994 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100995}
996
arovir011c7c81b2018-10-08 11:34:28 +0100997} // namespace armnn