blob: 3512d52acfd1929486e936b0c4ce0fe5f310c4aa [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
// Dispatches a per-data-type support query for the reference backend.
// Only the Float32 and QuantisedAsymm8 slots are caller-supplied; the
// Float16, Signed32 and Boolean slots are hard-wired to FalseFunc, i.e.
// those data types are reported as unsupported by default.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: unsupported
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>, // Signed32: unsupported
                                         &FalseFunc<Params...>, // Boolean: unsupported
                                         std::forward<Params>(params)...);
}
46
47} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
// Base class for the support-check predicates below. Each derived rule
// computes its verdict in its constructor and stores it in m_Res;
// operator() simply reports that stored result to CheckSupportRule.
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    // Verdict; defaults to true so rules only need to record failures.
    bool m_Res = true;
};
72
// Recursion terminator: a single remaining tensor is trivially
// type-consistent with itself.
template<typename T>
bool AllTypesAreEqualImpl(T)
{
    return true;
}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
// Rule: passes when all given TensorInfos have the same DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
// Rule: passes when both tensors have the same number of dimensions.
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};
124
// Rule: passes when both tensors contain the same total number of
// elements (shapes may differ in rank/layout).
struct ShapesAreSameTotalSize : public Rule
{
    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetNumElements() == info1.GetNumElements();
    }
};
132
// Rule: passes when the two input shapes can be implicitly broadcast to
// the output shape. Input dimensions are right-aligned against the
// output (numpy-style); each input dimension must either match the
// output dimension or be 1.
struct ShapesAreBroadcastCompatible : public Rule
{
    // Returns the effective size of input dimension idx after right-aligning
    // `in` against `out`; missing leading dimensions count as size 1.
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Stop early (m_Res in the loop condition) once a mismatch is found.
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            // Each input dim must equal the output dim or be broadcastable (==1).
            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};
159} // namespace
160
161
// Reports whether the reference backend supports the given activation
// layer configuration. All rules are always evaluated so that every
// applicable failure reason is appended to reasonIfUnsupported.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
   bool supported = true;

    // Define supported types.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: passes only for activation functions the reference
    // implementation provides; anything outside this list is rejected.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
223
224bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
225 const TensorInfo& input1,
226 const TensorInfo& output,
227 Optional<std::string&> reasonIfUnsupported) const
228{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000229 bool supported = true;
230
Sadik Armagan2999a022019-04-09 14:20:12 +0100231 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000232 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100233 DataType::QuantisedAsymm8,
234 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000235 };
236
237 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
238 "Reference addition: input 0 is not a supported type.");
239
240 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
241 "Reference addition: input 1 is not a supported type.");
242
243 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
244 "Reference addition: output is not a supported type.");
245
246 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
247 "Reference addition: input 0 and Input 1 types are mismatched");
248
249 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
250 "Reference addition: input and output types are mismatched");
251
252 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
253 "Reference addition: shapes are not suitable for implicit broadcast.");
254
255 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100256}
257
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8 via IsSupportedForDataTypeRef); the statistics,
// output and descriptor arguments do not affect the decision.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& var,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(mean);
    ignore_unused(var);
    ignore_unused(beta);
    ignore_unused(gamma);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
278
// Both input and output data types must pass the reference check
// (Float32 or QuantisedAsymm8); the descriptor is not inspected.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return (IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      input.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>) &&
            IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      output.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>));
}
294
// Constant layers are checked against the generic per-type dispatch.
// Based on the slot order used elsewhere in this file (F16, F32, U8,
// I32, Boolean — TODO confirm against LayerSupportCommon.hpp), this
// accepts Float32, QuantisedAsymm8 and Signed32 outputs and rejects
// Float16 and Boolean.
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &FalseFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFunc<>);
}
306
// Conversion layer: the helper-function names indicate the input must
// be Float16 and the output Float32; every other type combination is
// rejected with a type-specific reason.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
326
// Conversion layer: mirror of IsConvertFp16ToFp32Supported — input must
// be Float32 and output Float16 (per the helper-function names); all
// other combinations are rejected.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
346
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); weights, biases, output and descriptor are not
// validated here.
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
363
// Debug layers pass data through; support depends only on the input
// data type (Float32 or QuantisedAsymm8).
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
374
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); weights, biases, output and descriptor are not
// validated here.
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
391
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000392bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
393 const TensorInfo& output,
394 Optional<std::string&> reasonIfUnsupported) const
395{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100396 bool supported = true;
397
398 std::array<DataType,2> supportedInputTypes = {
399 DataType::QuantisedAsymm8,
400 DataType::QuantisedSymm16
401 };
402
403 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
404 "Reference dequantize: input type not supported.");
405
406 std::array<DataType,2> supportedOutputTypes = {
407 DataType::Float32,
408 };
409
410 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
411 "Reference dequantize: output type not supported.");
412
413 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
414 "Reference dequantize: input and output shapes have different num total elements.");
415
416 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000417}
418
// Support is decided on input0's data type alone (Float32 or
// QuantisedAsymm8); input1 and the descriptor are not inspected.
bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
430
arovir011c7c81b2018-10-08 11:34:28 +0100431bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
432 const TensorInfo& input1,
433 const TensorInfo& output,
434 Optional<std::string&> reasonIfUnsupported) const
435{
Sadik Armagan2999a022019-04-09 14:20:12 +0100436 bool supported = true;
437
438 std::array<DataType,3> supportedTypes = {
439 DataType::Float32,
440 DataType::QuantisedAsymm8,
441 DataType::QuantisedSymm16
442 };
443
444 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
445 "Reference division: input 0 is not a supported type.");
446
447 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
448 "Reference division: input 1 is not a supported type.");
449
450 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
451 "Reference division: output is not a supported type.");
452
453 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
454 "Reference division: input 0 and Input 1 types are mismatched");
455
456 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
457 "Reference division: input and output types are mismatched");
458
459 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
460 "Reference division: shapes are not suitable for implicit broadcast.");
461
462 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100463}
464
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000465bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
466 const TensorInfo& input1,
467 const TensorInfo& output,
468 Optional<std::string&> reasonIfUnsupported) const
469{
470 ignore_unused(input0);
471 ignore_unused(input1);
472 ignore_unused(output);
473 ignore_unused(reasonIfUnsupported);
474 return IsSupportedForDataTypeRef(reasonIfUnsupported,
475 input0.GetDataType(),
476 &TrueFunc<>,
477 &TrueFunc<>);
478}
479
// FakeQuantization is Float32-only in the reference backend
// (FalseFuncU8 rejects QuantisedAsymm8); the descriptor is ignored.
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
490
// Floor is Float32-only in the reference backend (FalseFuncU8 rejects
// QuantisedAsymm8); the output info is not inspected.
bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
501
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); weights, biases, output and descriptor are not
// validated here.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(weights);
    ignore_unused(biases);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
518
// Support is decided on input0's (the params tensor's) data type alone
// (Float32 or QuantisedAsymm8); the indices tensor and output are not
// inspected.
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
531
FrancisMurtagh878f0232018-12-19 10:56:15 +0000532bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
533 const TensorInfo& input1,
534 const TensorInfo& output,
535 Optional<std::string&> reasonIfUnsupported) const
536{
537 ignore_unused(input0);
538 ignore_unused(input1);
539 ignore_unused(output);
540 ignore_unused(reasonIfUnsupported);
541 return IsSupportedForDataTypeRef(reasonIfUnsupported,
542 input0.GetDataType(),
543 &TrueFunc<>,
544 &TrueFunc<>);
545}
546
// Input layers accept Float32 and QuantisedAsymm8 tensors.
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
555
// L2 normalization is Float32-only in the reference backend
// (FalseFuncU8 rejects QuantisedAsymm8); output and descriptor are not
// inspected.
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
568
// LSTM is Float32-only in the reference backend (FalseFuncU8 rejects
// QuantisedAsymm8). Only the main input's data type is examined; all
// state tensors, weight/bias tensors and the descriptor are ignored.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(outputStateIn);
    ignore_unused(cellStateIn);
    ignore_unused(scratchBuffer);
    ignore_unused(outputStateOut);
    ignore_unused(cellStateOut);
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
625
saoste012df12b32018-11-28 16:57:20 +0000626bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
627 const TensorInfo& input1,
628 const TensorInfo& output,
629 Optional<std::string&> reasonIfUnsupported) const
630{
Sadik Armagan2999a022019-04-09 14:20:12 +0100631 bool supported = true;
632
633 std::array<DataType,3> supportedTypes = {
634 DataType::Float32,
635 DataType::QuantisedAsymm8,
636 DataType::QuantisedSymm16
637 };
638
639 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
640 "Reference maximum: input 0 is not a supported type.");
641
642 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
643 "Reference maximum: input 1 is not a supported type.");
644
645 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
646 "Reference maximum: output is not a supported type.");
647
648 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
649 "Reference maximum: input 0 and Input 1 types are mismatched");
650
651 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
652 "Reference maximum: input and output types are mismatched");
653
654 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
655 "Reference maximum: shapes are not suitable for implicit broadcast.");
656
657 return supported;
saoste012df12b32018-11-28 16:57:20 +0000658}
659
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); output and descriptor are not inspected.
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
672
// Support is decided on the first input's data type alone (Float32 or
// QuantisedAsymm8). NOTE(review): inputs[0] is dereferenced without an
// emptiness check — presumably the layer contract guarantees at least
// one input; confirm against the caller.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     inputs[0]->GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
685
// MemCopy accepts every data type slot except the one rejected by
// FalseFuncI32 (Signed32, per the helper name — order follows
// LayerSupportCommon; confirm). Only the input type is examined.
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
699
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000700bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
701 const TensorInfo& input1,
702 const TensorInfo& output,
703 Optional<std::string&> reasonIfUnsupported) const
704{
Sadik Armagan2999a022019-04-09 14:20:12 +0100705 bool supported = true;
706
707 std::array<DataType,3> supportedTypes = {
708 DataType::Float32,
709 DataType::QuantisedAsymm8,
710 DataType::QuantisedSymm16
711 };
712
713 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
714 "Reference minimum: input 0 is not a supported type.");
715
716 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
717 "Reference minimum: input 1 is not a supported type.");
718
719 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
720 "Reference minimum: output is not a supported type.");
721
722 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
723 "Reference minimum: input 0 and Input 1 types are mismatched");
724
725 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
726 "Reference minimum: input and output types are mismatched");
727
728 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
729 "Reference minimum: shapes are not suitable for implicit broadcast.");
730
731 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000732}
733
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100734bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
735 const TensorInfo& input1,
736 const TensorInfo& output,
737 Optional<std::string&> reasonIfUnsupported) const
738{
Sadik Armagan2999a022019-04-09 14:20:12 +0100739 bool supported = true;
740
741 std::array<DataType,3> supportedTypes = {
742 DataType::Float32,
743 DataType::QuantisedAsymm8,
744 DataType::QuantisedSymm16
745 };
746
747 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
748 "Reference multiplication: input 0 is not a supported type.");
749
750 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
751 "Reference multiplication: input 1 is not a supported type.");
752
753 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
754 "Reference multiplication: output is not a supported type.");
755
756 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
757 "Reference multiplication: input 0 and Input 1 types are mismatched");
758
759 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
760 "Reference multiplication: input and output types are mismatched");
761
762 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
763 "Reference multiplication: shapes are not suitable for implicit broadcast.");
764
765 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100766}
767
// Normalization is Float32-only in the reference backend (FalseFuncU8
// rejects QuantisedAsymm8); output and descriptor are not inspected.
bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const NormalizationDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
780
// Output layers accept every data type slot except the one rejected by
// FalseFuncI32 (Signed32, per the helper name — order follows
// LayerSupportCommon; confirm), mirroring IsMemCopySupported.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
792
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); output and descriptor are not inspected.
bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
805
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); output and descriptor are not inspected.
bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
818
// Support is decided on the input data type alone (Float32 or
// QuantisedAsymm8); output and descriptor are not inspected.
bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const Pooling2dDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
831
Derek Lamberti5f400d62019-03-25 15:41:58 +0000832bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
833 const TensorInfo& output,
834 Optional<std::string&> reasonIfUnsupported) const
835{
836 bool supported = true;
837
838 // Define supported output types.
839 std::array<DataType,2> supportedInputTypes = {
840 DataType::Float32,
841 };
842
843 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
844 "Reference quantize: input type not supported.");
845
846 // Define supported output types.
847 std::array<DataType,2> supportedOutputTypes = {
848 DataType::QuantisedAsymm8,
849 DataType::QuantisedSymm16
850 };
851 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
852 "Reference quantize: output type not supported.");
853
854 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
855 "Reference quantize: input and output shapes have different num total elements.");
856
857 return supported;
858}
859
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100860bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000861 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100862 Optional<std::string&> reasonIfUnsupported) const
863{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000864 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100865 return IsSupportedForDataTypeRef(reasonIfUnsupported,
866 input.GetDataType(),
867 &TrueFunc<>,
868 &TrueFunc<>);
869}
870
871bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000872 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100873 Optional<std::string&> reasonIfUnsupported) const
874{
Sadik Armaganc625f002018-12-17 11:32:16 +0000875 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100876 return IsSupportedForDataTypeRef(reasonIfUnsupported,
877 input.GetDataType(),
878 &TrueFunc<>,
879 &TrueFunc<>);
880}
881
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000882bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
883 const TensorInfo& output,
884 Optional<std::string&> reasonIfUnsupported) const
885{
886 ignore_unused(output);
887 return IsSupportedForDataTypeRef(reasonIfUnsupported,
888 input.GetDataType(),
889 &TrueFunc<>,
890 &FalseFuncU8<>);
891}
892
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100893bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
894 const TensorInfo& output,
895 const SoftmaxDescriptor& descriptor,
896 Optional<std::string&> reasonIfUnsupported) const
897{
898 ignore_unused(output);
899 ignore_unused(descriptor);
900 return IsSupportedForDataTypeRef(reasonIfUnsupported,
901 input.GetDataType(),
902 &TrueFunc<>,
903 &TrueFunc<>);
904}
905
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000906bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
907 const TensorInfo& output,
908 const SpaceToBatchNdDescriptor& descriptor,
909 Optional<std::string&> reasonIfUnsupported) const
910{
911 ignore_unused(output);
912 ignore_unused(descriptor);
913 return IsSupportedForDataTypeRef(reasonIfUnsupported,
914 input.GetDataType(),
915 &TrueFunc<>,
916 &TrueFunc<>);
917}
918
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100919bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
920 const ViewsDescriptor& descriptor,
921 Optional<std::string&> reasonIfUnsupported) const
922{
923 ignore_unused(descriptor);
924 return IsSupportedForDataTypeRef(reasonIfUnsupported,
925 input.GetDataType(),
926 &TrueFunc<>,
927 &TrueFunc<>);
928}
929
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000930bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
931 const TensorInfo& output,
932 const StridedSliceDescriptor& descriptor,
933 Optional<std::string&> reasonIfUnsupported) const
934{
935 ignore_unused(output);
936 ignore_unused(descriptor);
937 return IsSupportedForDataTypeRef(reasonIfUnsupported,
938 input.GetDataType(),
939 &TrueFunc<>,
940 &TrueFunc<>);
941}
942
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100943bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
944 const TensorInfo& input1,
945 const TensorInfo& output,
946 Optional<std::string&> reasonIfUnsupported) const
947{
Sadik Armagan2999a022019-04-09 14:20:12 +0100948 bool supported = true;
949
950 std::array<DataType,3> supportedTypes = {
951 DataType::Float32,
952 DataType::QuantisedAsymm8,
953 DataType::QuantisedSymm16
954 };
955
956 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
957 "Reference subtraction: input 0 is not a supported type.");
958
959 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
960 "Reference subtraction: input 1 is not a supported type.");
961
962 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
963 "Reference subtraction: output is not a supported type.");
964
965 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
966 "Reference subtraction: input 0 and Input 1 types are mismatched");
967
968 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
969 "Reference subtraction: input and output types are mismatched");
970
971 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
972 "Reference subtraction: shapes are not suitable for implicit broadcast.");
973
974 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100975}
976
arovir011c7c81b2018-10-08 11:34:28 +0100977} // namespace armnn