blob: 1b1f0ce1c6405236d14f6b113454d1f749559a09 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
namespace
{

// Legacy support-check helper: dispatches on the tensor's DataType and runs
// the Float32 or QAsymm8 predicate, rejecting every other type outright.
// The &FalseFunc arguments fill the remaining positional slots of
// IsSupportedForDataTypeGeneric (see LayerSupportCommon.hpp for the slot
// order), so only the two supplied predicates can ever return true.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
// Base class for all support rules: a rule is simply a callable wrapping a
// boolean result computed in the derived class's constructor. Defaults to
// "supported" until a derived constructor proves otherwise.
struct Rule
{
    bool operator()() const { return m_Res; }

    bool m_Res = true;
};
72
// Recursion base case: a single remaining TensorInfo trivially matches.
template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

// Pairwise-compares the DataType of each adjacent TensorInfo in the pack;
// true only when every tensor shares the same DataType.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    // Compare the first pair, then recurse with t2 as the new head.
    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}
86
// Rule: passes when all given TensorInfos share the same DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
// Rule: passes when both input shapes can be implicitly broadcast to the
// output shape (numpy-style: dimensions are right-aligned; each input dim
// must equal the output dim or be 1).
struct ShapesAreBroadcastCompatible : public Rule
{
    // Effective size of input dimension idx once 'in' is right-aligned
    // against 'out'; dimensions missing from 'in' count as size 1.
    // NOTE(review): assumes out has at least as many dimensions as in,
    // otherwise the unsigned subtraction below wraps — confirm callers
    // always pass an output of equal or greater rank.
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Check every output dimension; the '&& m_Res' condition stops the
        // loop early once any dimension has already failed.
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            // Each input dim must match the output dim or be broadcastable (1).
            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};
159} // namespace
160
161
// Checks whether the reference backend can execute an Activation layer:
// input/output must use a supported type, agree in type and rank, and the
// requested activation function must be implemented. Each failed rule
// appends one line to reasonIfUnsupported.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: true only for the activation functions the reference
    // backend implements; anything else (e.g. future additions to the enum)
    // falls through to the unsupported default.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
224
225bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
226 const TensorInfo& input1,
227 const TensorInfo& output,
228 Optional<std::string&> reasonIfUnsupported) const
229{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000230 bool supported = true;
231
Sadik Armagan2999a022019-04-09 14:20:12 +0100232 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000233 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100234 DataType::QuantisedAsymm8,
235 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000236 };
237
238 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
239 "Reference addition: input 0 is not a supported type.");
240
241 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
242 "Reference addition: input 1 is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
245 "Reference addition: output is not a supported type.");
246
247 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
248 "Reference addition: input 0 and Input 1 types are mismatched");
249
250 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
251 "Reference addition: input and output types are mismatched");
252
253 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
254 "Reference addition: shapes are not suitable for implicit broadcast.");
255
256 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100257}
258
259bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
260 const TensorInfo& output,
261 const TensorInfo& mean,
262 const TensorInfo& var,
263 const TensorInfo& beta,
264 const TensorInfo& gamma,
265 const BatchNormalizationDescriptor& descriptor,
266 Optional<std::string&> reasonIfUnsupported) const
267{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100268 ignore_unused(output);
269 ignore_unused(mean);
270 ignore_unused(var);
271 ignore_unused(beta);
272 ignore_unused(gamma);
273 ignore_unused(descriptor);
274 return IsSupportedForDataTypeRef(reasonIfUnsupported,
275 input.GetDataType(),
276 &TrueFunc<>,
277 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000280bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const BatchToSpaceNdDescriptor& descriptor,
283 Optional<std::string&> reasonIfUnsupported) const
284{
285 ignore_unused(descriptor);
286 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
287 input.GetDataType(),
288 &TrueFunc<>,
289 &TrueFunc<>) &&
290 IsSupportedForDataTypeRef(reasonIfUnsupported,
291 output.GetDataType(),
292 &TrueFunc<>,
293 &TrueFunc<>));
294}
295
// Constant tensors: the per-type predicate slots below are positional
// (NOTE(review): order must match IsSupportedForDataTypeGeneric in
// LayerSupportCommon.hpp — the False/True pattern here rejects the first
// and last slots and accepts the middle three).
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &FalseFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFunc<>);
}
307
// Fp16->Fp32 conversion: the input check accepts only the first type slot
// and the output check accepts only the second; the named False*Func
// helpers presumably emit slot-specific diagnostics (see
// LayerSupportCommon.hpp for their exact wording and the slot ordering).
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
327
// Fp32->Fp16 conversion: mirror image of IsConvertFp16ToFp32Supported —
// the input check accepts only the second type slot, the output check only
// the first. Slot order must match IsSupportedForDataTypeGeneric.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
347
348bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
349 const TensorInfo& output,
350 const Convolution2dDescriptor& descriptor,
351 const TensorInfo& weights,
352 const Optional<TensorInfo>& biases,
353 Optional<std::string&> reasonIfUnsupported) const
354{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100355 ignore_unused(output);
356 ignore_unused(descriptor);
357 ignore_unused(weights);
358 ignore_unused(biases);
359 return IsSupportedForDataTypeRef(reasonIfUnsupported,
360 input.GetDataType(),
361 &TrueFunc<>,
362 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100363}
364
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000365bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
366 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000367 Optional<std::string&> reasonIfUnsupported) const
368{
369 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000370 return IsSupportedForDataTypeRef(reasonIfUnsupported,
371 input.GetDataType(),
372 &TrueFunc<>,
373 &TrueFunc<>);
374}
375
arovir011c7c81b2018-10-08 11:34:28 +0100376bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
377 const TensorInfo& output,
378 const DepthwiseConvolution2dDescriptor& descriptor,
379 const TensorInfo& weights,
380 const Optional<TensorInfo>& biases,
381 Optional<std::string&> reasonIfUnsupported) const
382{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100383 ignore_unused(output);
384 ignore_unused(descriptor);
385 ignore_unused(weights);
386 ignore_unused(biases);
387 return IsSupportedForDataTypeRef(reasonIfUnsupported,
388 input.GetDataType(),
389 &TrueFunc<>,
390 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100391}
392
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000393bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
394 const TensorInfo& output,
395 Optional<std::string&> reasonIfUnsupported) const
396{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100397 bool supported = true;
398
399 std::array<DataType,2> supportedInputTypes = {
400 DataType::QuantisedAsymm8,
401 DataType::QuantisedSymm16
402 };
403
404 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
405 "Reference dequantize: input type not supported.");
406
407 std::array<DataType,2> supportedOutputTypes = {
408 DataType::Float32,
409 };
410
411 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
412 "Reference dequantize: output type not supported.");
413
414 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
415 "Reference dequantize: input and output shapes have different num total elements.");
416
417 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000418}
419
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000420bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
421 const armnn::TensorInfo& input1,
422 const armnn::DetectionPostProcessDescriptor& descriptor,
423 armnn::Optional<std::string&> reasonIfUnsupported) const
424{
425 ignore_unused(input1);
426 return IsSupportedForDataTypeRef(reasonIfUnsupported,
427 input0.GetDataType(),
428 &TrueFunc<>,
429 &TrueFunc<>);
430}
431
arovir011c7c81b2018-10-08 11:34:28 +0100432bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
433 const TensorInfo& input1,
434 const TensorInfo& output,
435 Optional<std::string&> reasonIfUnsupported) const
436{
Sadik Armagan2999a022019-04-09 14:20:12 +0100437 bool supported = true;
438
439 std::array<DataType,3> supportedTypes = {
440 DataType::Float32,
441 DataType::QuantisedAsymm8,
442 DataType::QuantisedSymm16
443 };
444
445 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
446 "Reference division: input 0 is not a supported type.");
447
448 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
449 "Reference division: input 1 is not a supported type.");
450
451 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
452 "Reference division: output is not a supported type.");
453
454 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
455 "Reference division: input 0 and Input 1 types are mismatched");
456
457 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
458 "Reference division: input and output types are mismatched");
459
460 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
461 "Reference division: shapes are not suitable for implicit broadcast.");
462
463 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100464}
465
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000466bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
467 const TensorInfo& input1,
468 const TensorInfo& output,
469 Optional<std::string&> reasonIfUnsupported) const
470{
471 ignore_unused(input0);
472 ignore_unused(input1);
473 ignore_unused(output);
474 ignore_unused(reasonIfUnsupported);
475 return IsSupportedForDataTypeRef(reasonIfUnsupported,
476 input0.GetDataType(),
477 &TrueFunc<>,
478 &TrueFunc<>);
479}
480
arovir011c7c81b2018-10-08 11:34:28 +0100481bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
482 const FakeQuantizationDescriptor& descriptor,
483 Optional<std::string&> reasonIfUnsupported) const
484{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100485 ignore_unused(descriptor);
486 return IsSupportedForDataTypeRef(reasonIfUnsupported,
487 input.GetDataType(),
488 &TrueFunc<>,
489 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100490}
491
492bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
493 const TensorInfo& output,
494 Optional<std::string&> reasonIfUnsupported) const
495{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100496 ignore_unused(output);
497 return IsSupportedForDataTypeRef(reasonIfUnsupported,
498 input.GetDataType(),
499 &TrueFunc<>,
500 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100501}
502
503bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
504 const TensorInfo& output,
505 const TensorInfo& weights,
506 const TensorInfo& biases,
507 const FullyConnectedDescriptor& descriptor,
508 Optional<std::string&> reasonIfUnsupported) const
509{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100510 ignore_unused(output);
511 ignore_unused(weights);
512 ignore_unused(biases);
513 ignore_unused(descriptor);
514 return IsSupportedForDataTypeRef(reasonIfUnsupported,
515 input.GetDataType(),
516 &TrueFunc<>,
517 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100518}
519
narpra014951d842019-01-18 16:53:53 +0000520bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
521 const armnn::TensorInfo& input1,
522 const armnn::TensorInfo& output,
523 armnn::Optional<std::string&> reasonIfUnsupported) const
524{
525 ignore_unused(input1);
526 ignore_unused(output);
527 return IsSupportedForDataTypeRef(reasonIfUnsupported,
528 input0.GetDataType(),
529 &TrueFunc<>,
530 &TrueFunc<>);
531}
532
FrancisMurtagh878f0232018-12-19 10:56:15 +0000533bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
534 const TensorInfo& input1,
535 const TensorInfo& output,
536 Optional<std::string&> reasonIfUnsupported) const
537{
538 ignore_unused(input0);
539 ignore_unused(input1);
540 ignore_unused(output);
541 ignore_unused(reasonIfUnsupported);
542 return IsSupportedForDataTypeRef(reasonIfUnsupported,
543 input0.GetDataType(),
544 &TrueFunc<>,
545 &TrueFunc<>);
546}
547
arovir011c7c81b2018-10-08 11:34:28 +0100548bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
549 Optional<std::string&> reasonIfUnsupported) const
550{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100551 return IsSupportedForDataTypeRef(reasonIfUnsupported,
552 input.GetDataType(),
553 &TrueFunc<>,
554 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100555}
556
557bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
558 const TensorInfo& output,
559 const L2NormalizationDescriptor& descriptor,
560 Optional<std::string&> reasonIfUnsupported) const
561{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100562 ignore_unused(output);
563 ignore_unused(descriptor);
564 return IsSupportedForDataTypeRef(reasonIfUnsupported,
565 input.GetDataType(),
566 &TrueFunc<>,
567 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100568}
569
570bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
571 const TensorInfo& outputStateIn,
572 const TensorInfo& cellStateIn,
573 const TensorInfo& scratchBuffer,
574 const TensorInfo& outputStateOut,
575 const TensorInfo& cellStateOut,
576 const TensorInfo& output,
577 const LstmDescriptor& descriptor,
578 const TensorInfo& inputToForgetWeights,
579 const TensorInfo& inputToCellWeights,
580 const TensorInfo& inputToOutputWeights,
581 const TensorInfo& recurrentToForgetWeights,
582 const TensorInfo& recurrentToCellWeights,
583 const TensorInfo& recurrentToOutputWeights,
584 const TensorInfo& forgetGateBias,
585 const TensorInfo& cellBias,
586 const TensorInfo& outputGateBias,
587 const TensorInfo* inputToInputWeights,
588 const TensorInfo* recurrentToInputWeights,
589 const TensorInfo* cellToInputWeights,
590 const TensorInfo* inputGateBias,
591 const TensorInfo* projectionWeights,
592 const TensorInfo* projectionBias,
593 const TensorInfo* cellToForgetWeights,
594 const TensorInfo* cellToOutputWeights,
595 Optional<std::string&> reasonIfUnsupported) const
596{
telsoa01c577f2c2018-08-31 09:22:23 +0100597 ignore_unused(outputStateIn);
598 ignore_unused(cellStateIn);
599 ignore_unused(scratchBuffer);
600 ignore_unused(outputStateOut);
601 ignore_unused(cellStateOut);
602 ignore_unused(output);
603 ignore_unused(descriptor);
604 ignore_unused(inputToForgetWeights);
605 ignore_unused(inputToCellWeights);
606 ignore_unused(inputToOutputWeights);
607 ignore_unused(recurrentToForgetWeights);
608 ignore_unused(recurrentToCellWeights);
609 ignore_unused(recurrentToOutputWeights);
610 ignore_unused(forgetGateBias);
611 ignore_unused(cellBias);
612 ignore_unused(outputGateBias);
613 ignore_unused(inputToInputWeights);
614 ignore_unused(recurrentToInputWeights);
615 ignore_unused(cellToInputWeights);
616 ignore_unused(inputGateBias);
617 ignore_unused(projectionWeights);
618 ignore_unused(projectionBias);
619 ignore_unused(cellToForgetWeights);
620 ignore_unused(cellToOutputWeights);
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000621 return IsSupportedForDataTypeRef(reasonIfUnsupported,
622 input.GetDataType(),
623 &TrueFunc<>,
624 &FalseFuncU8<>);
telsoa01c577f2c2018-08-31 09:22:23 +0100625}
626
saoste012df12b32018-11-28 16:57:20 +0000627bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
628 const TensorInfo& input1,
629 const TensorInfo& output,
630 Optional<std::string&> reasonIfUnsupported) const
631{
Sadik Armagan2999a022019-04-09 14:20:12 +0100632 bool supported = true;
633
634 std::array<DataType,3> supportedTypes = {
635 DataType::Float32,
636 DataType::QuantisedAsymm8,
637 DataType::QuantisedSymm16
638 };
639
640 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
641 "Reference maximum: input 0 is not a supported type.");
642
643 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
644 "Reference maximum: input 1 is not a supported type.");
645
646 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
647 "Reference maximum: output is not a supported type.");
648
649 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
650 "Reference maximum: input 0 and Input 1 types are mismatched");
651
652 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
653 "Reference maximum: input and output types are mismatched");
654
655 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
656 "Reference maximum: shapes are not suitable for implicit broadcast.");
657
658 return supported;
saoste012df12b32018-11-28 16:57:20 +0000659}
660
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100661bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
662 const TensorInfo& output,
663 const MeanDescriptor& descriptor,
664 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100665{
narpra011e4c31d2018-09-28 11:07:51 +0100666 ignore_unused(output);
667 ignore_unused(descriptor);
668 return IsSupportedForDataTypeRef(reasonIfUnsupported,
669 input.GetDataType(),
670 &TrueFunc<>,
671 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100672}
673
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100674bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000675 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100676 const OriginsDescriptor& descriptor,
677 Optional<std::string&> reasonIfUnsupported) const
678{
679 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000680 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100681 return IsSupportedForDataTypeRef(reasonIfUnsupported,
682 inputs[0]->GetDataType(),
683 &TrueFunc<>,
684 &TrueFunc<>);
685}
686
// MemCopy: positional per-type table — every slot is accepted except the
// fourth (rejected via FalseFuncI32). NOTE(review): the slot order must
// match IsSupportedForDataTypeGeneric in LayerSupportCommon.hpp.
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
700
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000701bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
702 const TensorInfo& input1,
703 const TensorInfo& output,
704 Optional<std::string&> reasonIfUnsupported) const
705{
Sadik Armagan2999a022019-04-09 14:20:12 +0100706 bool supported = true;
707
708 std::array<DataType,3> supportedTypes = {
709 DataType::Float32,
710 DataType::QuantisedAsymm8,
711 DataType::QuantisedSymm16
712 };
713
714 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
715 "Reference minimum: input 0 is not a supported type.");
716
717 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
718 "Reference minimum: input 1 is not a supported type.");
719
720 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
721 "Reference minimum: output is not a supported type.");
722
723 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
724 "Reference minimum: input 0 and Input 1 types are mismatched");
725
726 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
727 "Reference minimum: input and output types are mismatched");
728
729 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
730 "Reference minimum: shapes are not suitable for implicit broadcast.");
731
732 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000733}
734
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100735bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
736 const TensorInfo& input1,
737 const TensorInfo& output,
738 Optional<std::string&> reasonIfUnsupported) const
739{
Sadik Armagan2999a022019-04-09 14:20:12 +0100740 bool supported = true;
741
742 std::array<DataType,3> supportedTypes = {
743 DataType::Float32,
744 DataType::QuantisedAsymm8,
745 DataType::QuantisedSymm16
746 };
747
748 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
749 "Reference multiplication: input 0 is not a supported type.");
750
751 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
752 "Reference multiplication: input 1 is not a supported type.");
753
754 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
755 "Reference multiplication: output is not a supported type.");
756
757 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
758 "Reference multiplication: input 0 and Input 1 types are mismatched");
759
760 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
761 "Reference multiplication: input and output types are mismatched");
762
763 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
764 "Reference multiplication: shapes are not suitable for implicit broadcast.");
765
766 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100767}
768
769bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
770 const TensorInfo& output,
771 const NormalizationDescriptor& descriptor,
772 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100773{
774 ignore_unused(output);
775 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100776 return IsSupportedForDataTypeRef(reasonIfUnsupported,
777 input.GetDataType(),
778 &TrueFunc<>,
779 &FalseFuncU8<>);
780}
781
// Graph output layers: positional per-type table — every slot accepted
// except the fourth (rejected via FalseFuncI32). NOTE(review): slot order
// must match IsSupportedForDataTypeGeneric in LayerSupportCommon.hpp.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
793
794bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
795 const TensorInfo& output,
796 const PadDescriptor& descriptor,
797 Optional<std::string&> reasonIfUnsupported) const
798{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100799 ignore_unused(output);
800 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000801 return IsSupportedForDataTypeRef(reasonIfUnsupported,
802 input.GetDataType(),
803 &TrueFunc<>,
804 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100805}
806
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100807bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
808 const TensorInfo& output,
809 const PermuteDescriptor& descriptor,
810 Optional<std::string&> reasonIfUnsupported) const
811{
812 ignore_unused(output);
813 ignore_unused(descriptor);
814 return IsSupportedForDataTypeRef(reasonIfUnsupported,
815 input.GetDataType(),
816 &TrueFunc<>,
817 &TrueFunc<>);
818}
819
820bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
821 const TensorInfo& output,
822 const Pooling2dDescriptor& descriptor,
823 Optional<std::string&> reasonIfUnsupported) const
824{
825 ignore_unused(output);
826 ignore_unused(descriptor);
827 return IsSupportedForDataTypeRef(reasonIfUnsupported,
828 input.GetDataType(),
829 &TrueFunc<>,
830 &TrueFunc<>);
831}
832
Derek Lamberti5f400d62019-03-25 15:41:58 +0000833bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
834 const TensorInfo& output,
835 Optional<std::string&> reasonIfUnsupported) const
836{
837 bool supported = true;
838
839 // Define supported output types.
840 std::array<DataType,2> supportedInputTypes = {
841 DataType::Float32,
842 };
843
844 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
845 "Reference quantize: input type not supported.");
846
847 // Define supported output types.
848 std::array<DataType,2> supportedOutputTypes = {
849 DataType::QuantisedAsymm8,
850 DataType::QuantisedSymm16
851 };
852 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
853 "Reference quantize: output type not supported.");
854
855 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
856 "Reference quantize: input and output shapes have different num total elements.");
857
858 return supported;
859}
860
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100861bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000862 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100863 Optional<std::string&> reasonIfUnsupported) const
864{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000865 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100866 return IsSupportedForDataTypeRef(reasonIfUnsupported,
867 input.GetDataType(),
868 &TrueFunc<>,
869 &TrueFunc<>);
870}
871
872bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000873 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100874 Optional<std::string&> reasonIfUnsupported) const
875{
Sadik Armaganc625f002018-12-17 11:32:16 +0000876 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100877 return IsSupportedForDataTypeRef(reasonIfUnsupported,
878 input.GetDataType(),
879 &TrueFunc<>,
880 &TrueFunc<>);
881}
882
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000883bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
884 const TensorInfo& output,
885 Optional<std::string&> reasonIfUnsupported) const
886{
887 ignore_unused(output);
888 return IsSupportedForDataTypeRef(reasonIfUnsupported,
889 input.GetDataType(),
890 &TrueFunc<>,
891 &FalseFuncU8<>);
892}
893
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100894bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
895 const TensorInfo& output,
896 const SoftmaxDescriptor& descriptor,
897 Optional<std::string&> reasonIfUnsupported) const
898{
899 ignore_unused(output);
900 ignore_unused(descriptor);
901 return IsSupportedForDataTypeRef(reasonIfUnsupported,
902 input.GetDataType(),
903 &TrueFunc<>,
904 &TrueFunc<>);
905}
906
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000907bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
908 const TensorInfo& output,
909 const SpaceToBatchNdDescriptor& descriptor,
910 Optional<std::string&> reasonIfUnsupported) const
911{
912 ignore_unused(output);
913 ignore_unused(descriptor);
914 return IsSupportedForDataTypeRef(reasonIfUnsupported,
915 input.GetDataType(),
916 &TrueFunc<>,
917 &TrueFunc<>);
918}
919
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100920bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
921 const ViewsDescriptor& descriptor,
922 Optional<std::string&> reasonIfUnsupported) const
923{
924 ignore_unused(descriptor);
925 return IsSupportedForDataTypeRef(reasonIfUnsupported,
926 input.GetDataType(),
927 &TrueFunc<>,
928 &TrueFunc<>);
929}
930
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000931bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
932 const TensorInfo& output,
933 const StridedSliceDescriptor& descriptor,
934 Optional<std::string&> reasonIfUnsupported) const
935{
936 ignore_unused(output);
937 ignore_unused(descriptor);
938 return IsSupportedForDataTypeRef(reasonIfUnsupported,
939 input.GetDataType(),
940 &TrueFunc<>,
941 &TrueFunc<>);
942}
943
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100944bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
945 const TensorInfo& input1,
946 const TensorInfo& output,
947 Optional<std::string&> reasonIfUnsupported) const
948{
Sadik Armagan2999a022019-04-09 14:20:12 +0100949 bool supported = true;
950
951 std::array<DataType,3> supportedTypes = {
952 DataType::Float32,
953 DataType::QuantisedAsymm8,
954 DataType::QuantisedSymm16
955 };
956
957 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
958 "Reference subtraction: input 0 is not a supported type.");
959
960 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
961 "Reference subtraction: input 1 is not a supported type.");
962
963 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
964 "Reference subtraction: output is not a supported type.");
965
966 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
967 "Reference subtraction: input 0 and Input 1 types are mismatched");
968
969 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
970 "Reference subtraction: input and output types are mismatched");
971
972 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
973 "Reference subtraction: shapes are not suitable for implicit broadcast.");
974
975 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100976}
977
arovir011c7c81b2018-10-08 11:34:28 +0100978} // namespace armnn