blob: f79c152139b07abdc5a1f0ebefe43f4abb3b37c7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
// Forwards a data-type support query to the generic dispatcher, enabling only
// the Float32 and QuantisedAsymm8 slots. The generic dispatcher takes one
// functor per DataType; judging from the ConvertFp16/Fp32 checks below, the
// slot order is: Float16, Float32, QuantisedAsymm8, Signed32, Boolean.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: never supported here
                                         floatFuncPtr,          // Float32
                                         uint8FuncPtr,          // QuantisedAsymm8
                                         &FalseFunc<Params...>, // Signed32: never supported here
                                         &FalseFunc<Params...>, // Boolean: never supported here
                                         std::forward<Params>(params)...);
}
46
47} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
// Base class for the layer-support rules below: a derived rule computes its
// verdict in its constructor and stores it in m_Res; invoking the rule simply
// reports that precomputed verdict.
struct Rule
{
    // Report the verdict computed at construction time.
    bool operator()() const { return m_Res; }

    bool m_Res = true; // Defaults to "supported" until a derived rule decides otherwise.
};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
// Recursive case: the pack is "all equal" when the first two TensorInfos share
// a DataType and the tail (t2 onwards) is itself all equal. Restricted to
// TensorInfo at compile time.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}
86
// Rule: every given TensorInfo has the same DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};
95
// Rule: both tensors carry identical quantization parameters (scale and
// zero-point offset).
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
// Rule: both tensors have shapes of the same rank (number of dimensions).
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
Derek Lamberti50db4e82019-03-13 14:16:15 +0000133struct ShapesAreBroadcastCompatible : public Rule
134{
135 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
136 {
137 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
138 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
139 return sizeIn;
140 }
141
142 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
143 {
144 const TensorShape& shape0 = in0.GetShape();
145 const TensorShape& shape1 = in1.GetShape();
146 const TensorShape& outShape = out.GetShape();
147
148 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
149 {
150 unsigned int sizeOut = outShape[i];
151 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
152 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
153
154 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
155 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
156 }
157 }
158};
159} // namespace
160
161
arovir011c7c81b2018-10-08 11:34:28 +0100162bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const ActivationDescriptor& descriptor,
165 Optional<std::string&> reasonIfUnsupported) const
166{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000167 bool supported = true;
168
169 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100170 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000171 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100172 DataType::QuantisedAsymm8,
173 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000174 };
175
176 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
177 "Reference activation: input type not supported.");
178
179 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
180 "Reference activation: output type not supported.");
181
182 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
183 "Reference activation: input and output types mismatched.");
184
185 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
186 "Reference activation: input and output shapes are of different rank.");
187
188
189 struct ActivationFunctionSupported : public Rule
190 {
191 ActivationFunctionSupported(const ActivationDescriptor& desc)
192 {
193 switch(desc.m_Function)
194 {
195 case ActivationFunction::Abs:
196 case ActivationFunction::BoundedReLu:
197 case ActivationFunction::LeakyReLu:
198 case ActivationFunction::Linear:
199 case ActivationFunction::ReLu:
200 case ActivationFunction::Sigmoid:
201 case ActivationFunction::SoftReLu:
202 case ActivationFunction::Sqrt:
203 case ActivationFunction::Square:
204 case ActivationFunction::TanH:
205 {
206 m_Res = true;
207 break;
208 }
209 default:
210 {
211 m_Res = false;
212 break;
213 }
214 }
215 }
216 };
217
218 // Function is supported
219 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
220 "Reference activation: function not supported.");
221
222 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100223}
224
// Reference-backend support check for Addition. Both inputs and the output
// must share one of the supported types, and the input shapes must be
// implicitly broadcastable to the output shape.
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}
258
259bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
260 const TensorInfo& output,
261 const TensorInfo& mean,
262 const TensorInfo& var,
263 const TensorInfo& beta,
264 const TensorInfo& gamma,
265 const BatchNormalizationDescriptor& descriptor,
266 Optional<std::string&> reasonIfUnsupported) const
267{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100268 ignore_unused(output);
269 ignore_unused(mean);
270 ignore_unused(var);
271 ignore_unused(beta);
272 ignore_unused(gamma);
273 ignore_unused(descriptor);
274 return IsSupportedForDataTypeRef(reasonIfUnsupported,
275 input.GetDataType(),
276 &TrueFunc<>,
277 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000280bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const BatchToSpaceNdDescriptor& descriptor,
283 Optional<std::string&> reasonIfUnsupported) const
284{
285 ignore_unused(descriptor);
286 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
287 input.GetDataType(),
288 &TrueFunc<>,
289 &TrueFunc<>) &&
290 IsSupportedForDataTypeRef(reasonIfUnsupported,
291 output.GetDataType(),
292 &TrueFunc<>,
293 &TrueFunc<>));
294}
295
Jim Flynn906f9462019-05-10 13:55:21 +0100296bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
297 const TensorInfo& output,
298 const OriginsDescriptor& descriptor,
299 Optional<std::string&> reasonIfUnsupported) const
300{
301 ARMNN_NO_DEPRECATE_WARN_BEGIN
302 return IsMergerSupported(inputs, output, descriptor, reasonIfUnsupported);
303 ARMNN_NO_DEPRECATE_WARN_END
304}
305
arovir011c7c81b2018-10-08 11:34:28 +0100306bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
307 Optional<std::string&> reasonIfUnsupported) const
308{
narpra01db2b1602019-01-23 15:23:11 +0000309 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
310 output.GetDataType(),
311 &FalseFunc<>,
312 &TrueFunc<>,
313 &TrueFunc<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000314 &TrueFunc<>,
315 &FalseFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100316}
317
// Reference-backend support check for ConvertFp16ToFp32: input must be
// Float16 and output Float32; all other type combinations are rejected with
// specific error functors.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,           // Float16 input: ok
                                          &FalseInputFuncF32<>,  // Float32 input: wrong direction
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output: wrong direction
                                          &TrueFunc<>,           // Float32 output: ok
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
337
// Reference-backend support check for ConvertFp32ToFp16: mirror image of the
// Fp16->Fp32 check - input must be Float32, output Float16.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,  // Float16 input: wrong direction
                                          &TrueFunc<>,           // Float32 input: ok
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output: ok
                                          &FalseOutputFuncF32<>, // Float32 output: wrong direction
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
357
358bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
359 const TensorInfo& output,
360 const Convolution2dDescriptor& descriptor,
361 const TensorInfo& weights,
362 const Optional<TensorInfo>& biases,
363 Optional<std::string&> reasonIfUnsupported) const
364{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100365 ignore_unused(output);
366 ignore_unused(descriptor);
367 ignore_unused(weights);
368 ignore_unused(biases);
369 return IsSupportedForDataTypeRef(reasonIfUnsupported,
370 input.GetDataType(),
371 &TrueFunc<>,
372 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100373}
374
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000375bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
376 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000377 Optional<std::string&> reasonIfUnsupported) const
378{
379 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000380 return IsSupportedForDataTypeRef(reasonIfUnsupported,
381 input.GetDataType(),
382 &TrueFunc<>,
383 &TrueFunc<>);
384}
385
arovir011c7c81b2018-10-08 11:34:28 +0100386bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
387 const TensorInfo& output,
388 const DepthwiseConvolution2dDescriptor& descriptor,
389 const TensorInfo& weights,
390 const Optional<TensorInfo>& biases,
391 Optional<std::string&> reasonIfUnsupported) const
392{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100393 ignore_unused(output);
394 ignore_unused(descriptor);
395 ignore_unused(weights);
396 ignore_unused(biases);
397 return IsSupportedForDataTypeRef(reasonIfUnsupported,
398 input.GetDataType(),
399 &TrueFunc<>,
400 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100401}
402
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000403bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
404 const TensorInfo& output,
405 Optional<std::string&> reasonIfUnsupported) const
406{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100407 bool supported = true;
408
409 std::array<DataType,2> supportedInputTypes = {
410 DataType::QuantisedAsymm8,
411 DataType::QuantisedSymm16
412 };
413
414 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
415 "Reference dequantize: input type not supported.");
416
417 std::array<DataType,2> supportedOutputTypes = {
418 DataType::Float32,
419 };
420
421 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
422 "Reference dequantize: output type not supported.");
423
424 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
425 "Reference dequantize: input and output shapes have different num total elements.");
426
427 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000428}
429
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000430bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
431 const armnn::TensorInfo& input1,
432 const armnn::DetectionPostProcessDescriptor& descriptor,
433 armnn::Optional<std::string&> reasonIfUnsupported) const
434{
435 ignore_unused(input1);
436 return IsSupportedForDataTypeRef(reasonIfUnsupported,
437 input0.GetDataType(),
438 &TrueFunc<>,
439 &TrueFunc<>);
440}
441
Pablo Tellof0bd6832019-04-26 17:58:13 +0100442bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
443 const TensorInfo& output,
444 const DepthwiseConvolution2dDescriptor& descriptor,
445 const TensorInfo& weights,
446 const Optional<TensorInfo>& biases,
447 Optional<std::string&> reasonIfUnsupported) const
448{
449 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
450 {
451 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
452 }
453 else
454 {
455 if (reasonIfUnsupported)
456 {
457 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
458 }
459 return false;
460 }
461}
462
// Reference-backend support check for Division. Both inputs and the output
// must share one of the supported types, and the input shapes must be
// implicitly broadcastable to the output shape.
bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}
497
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000498bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
499 const TensorInfo& input1,
500 const TensorInfo& output,
501 Optional<std::string&> reasonIfUnsupported) const
502{
503 ignore_unused(input0);
504 ignore_unused(input1);
505 ignore_unused(output);
506 ignore_unused(reasonIfUnsupported);
507 return IsSupportedForDataTypeRef(reasonIfUnsupported,
508 input0.GetDataType(),
509 &TrueFunc<>,
510 &TrueFunc<>);
511}
512
arovir011c7c81b2018-10-08 11:34:28 +0100513bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
514 const FakeQuantizationDescriptor& descriptor,
515 Optional<std::string&> reasonIfUnsupported) const
516{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100517 ignore_unused(descriptor);
518 return IsSupportedForDataTypeRef(reasonIfUnsupported,
519 input.GetDataType(),
520 &TrueFunc<>,
521 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100522}
523
524bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
525 const TensorInfo& output,
526 Optional<std::string&> reasonIfUnsupported) const
527{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100528 ignore_unused(output);
529 return IsSupportedForDataTypeRef(reasonIfUnsupported,
530 input.GetDataType(),
531 &TrueFunc<>,
532 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100533}
534
535bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
536 const TensorInfo& output,
537 const TensorInfo& weights,
538 const TensorInfo& biases,
539 const FullyConnectedDescriptor& descriptor,
540 Optional<std::string&> reasonIfUnsupported) const
541{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100542 ignore_unused(output);
543 ignore_unused(weights);
544 ignore_unused(biases);
545 ignore_unused(descriptor);
546 return IsSupportedForDataTypeRef(reasonIfUnsupported,
547 input.GetDataType(),
548 &TrueFunc<>,
549 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100550}
551
narpra014951d842019-01-18 16:53:53 +0000552bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
553 const armnn::TensorInfo& input1,
554 const armnn::TensorInfo& output,
555 armnn::Optional<std::string&> reasonIfUnsupported) const
556{
557 ignore_unused(input1);
558 ignore_unused(output);
559 return IsSupportedForDataTypeRef(reasonIfUnsupported,
560 input0.GetDataType(),
561 &TrueFunc<>,
562 &TrueFunc<>);
563}
564
FrancisMurtagh878f0232018-12-19 10:56:15 +0000565bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
566 const TensorInfo& input1,
567 const TensorInfo& output,
568 Optional<std::string&> reasonIfUnsupported) const
569{
570 ignore_unused(input0);
571 ignore_unused(input1);
572 ignore_unused(output);
573 ignore_unused(reasonIfUnsupported);
574 return IsSupportedForDataTypeRef(reasonIfUnsupported,
575 input0.GetDataType(),
576 &TrueFunc<>,
577 &TrueFunc<>);
578}
579
arovir011c7c81b2018-10-08 11:34:28 +0100580bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
581 Optional<std::string&> reasonIfUnsupported) const
582{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100583 return IsSupportedForDataTypeRef(reasonIfUnsupported,
584 input.GetDataType(),
585 &TrueFunc<>,
586 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100587}
588
589bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
590 const TensorInfo& output,
591 const L2NormalizationDescriptor& descriptor,
592 Optional<std::string&> reasonIfUnsupported) const
593{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100594 ignore_unused(output);
595 ignore_unused(descriptor);
596 return IsSupportedForDataTypeRef(reasonIfUnsupported,
597 input.GetDataType(),
598 &TrueFunc<>,
599 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100600}
601
// Reference-backend support check for Lstm. Only the data types of the
// activation tensors are checked: the input must be Float32 or
// QuantisedSymm16, and every state/output tensor must match the input's type.
// The weight/bias tensors and the descriptor are not inspected here.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // The weight, bias and descriptor parameters play no part in the support
    // decision; silence unused-parameter warnings for all of them.
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
678
saoste012df12b32018-11-28 16:57:20 +0000679bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
680 const TensorInfo& input1,
681 const TensorInfo& output,
682 Optional<std::string&> reasonIfUnsupported) const
683{
Sadik Armagan2999a022019-04-09 14:20:12 +0100684 bool supported = true;
685
686 std::array<DataType,3> supportedTypes = {
687 DataType::Float32,
688 DataType::QuantisedAsymm8,
689 DataType::QuantisedSymm16
690 };
691
692 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
693 "Reference maximum: input 0 is not a supported type.");
694
695 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
696 "Reference maximum: input 1 is not a supported type.");
697
698 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
699 "Reference maximum: output is not a supported type.");
700
701 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
702 "Reference maximum: input 0 and Input 1 types are mismatched");
703
704 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
705 "Reference maximum: input and output types are mismatched");
706
707 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
708 "Reference maximum: shapes are not suitable for implicit broadcast.");
709
710 return supported;
saoste012df12b32018-11-28 16:57:20 +0000711}
712
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100713bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
714 const TensorInfo& output,
715 const MeanDescriptor& descriptor,
716 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100717{
narpra011e4c31d2018-09-28 11:07:51 +0100718 ignore_unused(output);
719 ignore_unused(descriptor);
720 return IsSupportedForDataTypeRef(reasonIfUnsupported,
721 input.GetDataType(),
722 &TrueFunc<>,
723 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100724}
725
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100726bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000727 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100728 const OriginsDescriptor& descriptor,
729 Optional<std::string&> reasonIfUnsupported) const
730{
731 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000732 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100733 return IsSupportedForDataTypeRef(reasonIfUnsupported,
734 inputs[0]->GetDataType(),
735 &TrueFunc<>,
736 &TrueFunc<>);
737}
738
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000739bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
740 const TensorInfo &output,
741 Optional<std::string &> reasonIfUnsupported) const
742{
743 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000744 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
745 input.GetDataType(),
746 &TrueFunc<>,
747 &TrueFunc<>,
748 &TrueFunc<>,
749 &FalseFuncI32<>,
750 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000751}
752
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000753bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
754 const TensorInfo& input1,
755 const TensorInfo& output,
756 Optional<std::string&> reasonIfUnsupported) const
757{
Sadik Armagan2999a022019-04-09 14:20:12 +0100758 bool supported = true;
759
760 std::array<DataType,3> supportedTypes = {
761 DataType::Float32,
762 DataType::QuantisedAsymm8,
763 DataType::QuantisedSymm16
764 };
765
766 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
767 "Reference minimum: input 0 is not a supported type.");
768
769 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
770 "Reference minimum: input 1 is not a supported type.");
771
772 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
773 "Reference minimum: output is not a supported type.");
774
775 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
776 "Reference minimum: input 0 and Input 1 types are mismatched");
777
778 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
779 "Reference minimum: input and output types are mismatched");
780
781 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
782 "Reference minimum: shapes are not suitable for implicit broadcast.");
783
784 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000785}
786
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100787bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
788 const TensorInfo& input1,
789 const TensorInfo& output,
790 Optional<std::string&> reasonIfUnsupported) const
791{
Sadik Armagan2999a022019-04-09 14:20:12 +0100792 bool supported = true;
793
794 std::array<DataType,3> supportedTypes = {
795 DataType::Float32,
796 DataType::QuantisedAsymm8,
797 DataType::QuantisedSymm16
798 };
799
800 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
801 "Reference multiplication: input 0 is not a supported type.");
802
803 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
804 "Reference multiplication: input 1 is not a supported type.");
805
806 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
807 "Reference multiplication: output is not a supported type.");
808
809 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
810 "Reference multiplication: input 0 and Input 1 types are mismatched");
811
812 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
813 "Reference multiplication: input and output types are mismatched");
814
815 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
816 "Reference multiplication: shapes are not suitable for implicit broadcast.");
817
818 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100819}
820
821bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
822 const TensorInfo& output,
823 const NormalizationDescriptor& descriptor,
824 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100825{
826 ignore_unused(output);
827 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100828 return IsSupportedForDataTypeRef(reasonIfUnsupported,
829 input.GetDataType(),
830 &TrueFunc<>,
831 &FalseFuncU8<>);
832}
833
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Output layers accept every data type handled by the generic dispatcher
    // except the slot served by FalseFuncI32 (presumably Signed32 — confirm
    // the functor ordering against IsSupportedForDataTypeGeneric in
    // LayerSupportCommon.hpp).
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
845
846bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
847 const TensorInfo& output,
848 const PadDescriptor& descriptor,
849 Optional<std::string&> reasonIfUnsupported) const
850{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100851 ignore_unused(output);
852 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000853 return IsSupportedForDataTypeRef(reasonIfUnsupported,
854 input.GetDataType(),
855 &TrueFunc<>,
856 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100857}
858
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100859bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
860 const TensorInfo& output,
861 const PermuteDescriptor& descriptor,
862 Optional<std::string&> reasonIfUnsupported) const
863{
864 ignore_unused(output);
865 ignore_unused(descriptor);
866 return IsSupportedForDataTypeRef(reasonIfUnsupported,
867 input.GetDataType(),
868 &TrueFunc<>,
869 &TrueFunc<>);
870}
871
872bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
873 const TensorInfo& output,
874 const Pooling2dDescriptor& descriptor,
875 Optional<std::string&> reasonIfUnsupported) const
876{
877 ignore_unused(output);
878 ignore_unused(descriptor);
879 return IsSupportedForDataTypeRef(reasonIfUnsupported,
880 input.GetDataType(),
881 &TrueFunc<>,
882 &TrueFunc<>);
883}
884
Derek Lamberti5f400d62019-03-25 15:41:58 +0000885bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
886 const TensorInfo& output,
887 Optional<std::string&> reasonIfUnsupported) const
888{
889 bool supported = true;
890
891 // Define supported output types.
892 std::array<DataType,2> supportedInputTypes = {
893 DataType::Float32,
894 };
895
896 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
897 "Reference quantize: input type not supported.");
898
899 // Define supported output types.
900 std::array<DataType,2> supportedOutputTypes = {
901 DataType::QuantisedAsymm8,
902 DataType::QuantisedSymm16
903 };
904 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
905 "Reference quantize: output type not supported.");
906
907 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
908 "Reference quantize: input and output shapes have different num total elements.");
909
910 return supported;
911}
912
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100913bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000914 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100915 Optional<std::string&> reasonIfUnsupported) const
916{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000917 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100918 return IsSupportedForDataTypeRef(reasonIfUnsupported,
919 input.GetDataType(),
920 &TrueFunc<>,
921 &TrueFunc<>);
922}
923
924bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000925 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100926 Optional<std::string&> reasonIfUnsupported) const
927{
Sadik Armaganc625f002018-12-17 11:32:16 +0000928 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100929 return IsSupportedForDataTypeRef(reasonIfUnsupported,
930 input.GetDataType(),
931 &TrueFunc<>,
932 &TrueFunc<>);
933}
934
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000935bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
936 const TensorInfo& output,
937 Optional<std::string&> reasonIfUnsupported) const
938{
939 ignore_unused(output);
940 return IsSupportedForDataTypeRef(reasonIfUnsupported,
941 input.GetDataType(),
942 &TrueFunc<>,
943 &FalseFuncU8<>);
944}
945
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100946bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
947 const TensorInfo& output,
948 const SoftmaxDescriptor& descriptor,
949 Optional<std::string&> reasonIfUnsupported) const
950{
951 ignore_unused(output);
952 ignore_unused(descriptor);
953 return IsSupportedForDataTypeRef(reasonIfUnsupported,
954 input.GetDataType(),
955 &TrueFunc<>,
956 &TrueFunc<>);
957}
958
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000959bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
960 const TensorInfo& output,
961 const SpaceToBatchNdDescriptor& descriptor,
962 Optional<std::string&> reasonIfUnsupported) const
963{
964 ignore_unused(output);
965 ignore_unused(descriptor);
966 return IsSupportedForDataTypeRef(reasonIfUnsupported,
967 input.GetDataType(),
968 &TrueFunc<>,
969 &TrueFunc<>);
970}
971
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100972bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
973 const ViewsDescriptor& descriptor,
974 Optional<std::string&> reasonIfUnsupported) const
975{
976 ignore_unused(descriptor);
977 return IsSupportedForDataTypeRef(reasonIfUnsupported,
978 input.GetDataType(),
979 &TrueFunc<>,
980 &TrueFunc<>);
981}
982
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000983bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
984 const TensorInfo& output,
985 const StridedSliceDescriptor& descriptor,
986 Optional<std::string&> reasonIfUnsupported) const
987{
988 ignore_unused(output);
989 ignore_unused(descriptor);
990 return IsSupportedForDataTypeRef(reasonIfUnsupported,
991 input.GetDataType(),
992 &TrueFunc<>,
993 &TrueFunc<>);
994}
995
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100996bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
997 const TensorInfo& input1,
998 const TensorInfo& output,
999 Optional<std::string&> reasonIfUnsupported) const
1000{
Sadik Armagan2999a022019-04-09 14:20:12 +01001001 bool supported = true;
1002
1003 std::array<DataType,3> supportedTypes = {
1004 DataType::Float32,
1005 DataType::QuantisedAsymm8,
1006 DataType::QuantisedSymm16
1007 };
1008
1009 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1010 "Reference subtraction: input 0 is not a supported type.");
1011
1012 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1013 "Reference subtraction: input 1 is not a supported type.");
1014
1015 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1016 "Reference subtraction: output is not a supported type.");
1017
1018 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1019 "Reference subtraction: input 0 and Input 1 types are mismatched");
1020
1021 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1022 "Reference subtraction: input and output types are mismatched");
1023
1024 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1025 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1026
1027 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001028}
1029
arovir011c7c81b2018-10-08 11:34:28 +01001030} // namespace armnn