blob: d2cf6f904af94618c87f9d7441b9a9d0705756a5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
30template<typename Float32Func, typename Uint8Func, typename ... Params>
31bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
32 DataType dataType,
33 Float32Func floatFuncPtr,
34 Uint8Func uint8FuncPtr,
35 Params&&... params)
36{
37 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
38 dataType,
39 &FalseFunc<Params...>,
40 floatFuncPtr,
41 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000042 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000043 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 std::forward<Params>(params)...);
45}
46
47} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
// Common base of the support predicates in this file.
// A derived rule computes its verdict in its constructor and stores it in m_Res.
// Invoking the rule simply reports that stored verdict.
struct Rule
{
    bool operator()() const { return m_Res; }

    // Verdict of the rule; a rule passes unless a derived constructor clears it.
    bool m_Res = true;
};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
87struct TypesAreEqual : public Rule
88{
89 template<typename ... Ts>
90 TypesAreEqual(const Ts&... ts)
91 {
92 m_Res = AllTypesAreEqualImpl(ts...);
93 }
94};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
Derek Lamberti5f400d62019-03-25 15:41:58 +0000125struct ShapesAreSameTotalSize : public Rule
126{
127 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
128 {
129 m_Res = info0.GetNumElements() == info1.GetNumElements();
130 }
131};
132
Derek Lamberti50db4e82019-03-13 14:16:15 +0000133struct ShapesAreBroadcastCompatible : public Rule
134{
135 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
136 {
137 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
138 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
139 return sizeIn;
140 }
141
142 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
143 {
144 const TensorShape& shape0 = in0.GetShape();
145 const TensorShape& shape1 = in1.GetShape();
146 const TensorShape& outShape = out.GetShape();
147
148 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
149 {
150 unsigned int sizeOut = outShape[i];
151 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
152 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
153
154 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
155 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
156 }
157 }
158};
159} // namespace
160
161
arovir011c7c81b2018-10-08 11:34:28 +0100162bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const ActivationDescriptor& descriptor,
165 Optional<std::string&> reasonIfUnsupported) const
166{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000167 bool supported = true;
168
169 // Define supported types.
170 std::array<DataType,2> supportedTypes = {
171 DataType::Float32,
172 DataType::QuantisedAsymm8
173 };
174
175 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
176 "Reference activation: input type not supported.");
177
178 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
179 "Reference activation: output type not supported.");
180
181 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
182 "Reference activation: input and output types mismatched.");
183
184 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
185 "Reference activation: input and output shapes are of different rank.");
186
187
188 struct ActivationFunctionSupported : public Rule
189 {
190 ActivationFunctionSupported(const ActivationDescriptor& desc)
191 {
192 switch(desc.m_Function)
193 {
194 case ActivationFunction::Abs:
195 case ActivationFunction::BoundedReLu:
196 case ActivationFunction::LeakyReLu:
197 case ActivationFunction::Linear:
198 case ActivationFunction::ReLu:
199 case ActivationFunction::Sigmoid:
200 case ActivationFunction::SoftReLu:
201 case ActivationFunction::Sqrt:
202 case ActivationFunction::Square:
203 case ActivationFunction::TanH:
204 {
205 m_Res = true;
206 break;
207 }
208 default:
209 {
210 m_Res = false;
211 break;
212 }
213 }
214 }
215 };
216
217 // Function is supported
218 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
219 "Reference activation: function not supported.");
220
221 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100222}
223
224bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
225 const TensorInfo& input1,
226 const TensorInfo& output,
227 Optional<std::string&> reasonIfUnsupported) const
228{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000229 bool supported = true;
230
231 std::array<DataType,2> supportedTypes = {
232 DataType::Float32,
233 DataType::QuantisedAsymm8
234 };
235
236 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
237 "Reference addition: input 0 is not a supported type.");
238
239 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
240 "Reference addition: input 1 is not a supported type.");
241
242 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
243 "Reference addition: output is not a supported type.");
244
245 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
246 "Reference addition: input 0 and Input 1 types are mismatched");
247
248 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
249 "Reference addition: input and output types are mismatched");
250
251 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
252 "Reference addition: shapes are not suitable for implicit broadcast.");
253
254 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100255}
256
257bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
258 const TensorInfo& output,
259 const TensorInfo& mean,
260 const TensorInfo& var,
261 const TensorInfo& beta,
262 const TensorInfo& gamma,
263 const BatchNormalizationDescriptor& descriptor,
264 Optional<std::string&> reasonIfUnsupported) const
265{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100266 ignore_unused(output);
267 ignore_unused(mean);
268 ignore_unused(var);
269 ignore_unused(beta);
270 ignore_unused(gamma);
271 ignore_unused(descriptor);
272 return IsSupportedForDataTypeRef(reasonIfUnsupported,
273 input.GetDataType(),
274 &TrueFunc<>,
275 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100276}
277
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000278bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
279 const TensorInfo& output,
280 const BatchToSpaceNdDescriptor& descriptor,
281 Optional<std::string&> reasonIfUnsupported) const
282{
283 ignore_unused(descriptor);
284 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
285 input.GetDataType(),
286 &TrueFunc<>,
287 &TrueFunc<>) &&
288 IsSupportedForDataTypeRef(reasonIfUnsupported,
289 output.GetDataType(),
290 &TrueFunc<>,
291 &TrueFunc<>));
292}
293
arovir011c7c81b2018-10-08 11:34:28 +0100294bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
295 Optional<std::string&> reasonIfUnsupported) const
296{
narpra01db2b1602019-01-23 15:23:11 +0000297 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
298 output.GetDataType(),
299 &FalseFunc<>,
300 &TrueFunc<>,
301 &TrueFunc<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000302 &TrueFunc<>,
303 &FalseFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100304}
305
306bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
307 const TensorInfo& output,
308 Optional<std::string&> reasonIfUnsupported) const
309{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100310 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
311 input.GetDataType(),
312 &TrueFunc<>,
313 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000314 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000315 &FalseFuncI32<>,
316 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100317 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
318 output.GetDataType(),
319 &FalseOutputFuncF16<>,
320 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000321 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000322 &FalseFuncI32<>,
323 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100324}
325
326bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
327 const TensorInfo& output,
328 Optional<std::string&> reasonIfUnsupported) const
329{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100330 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
331 input.GetDataType(),
332 &FalseInputFuncF16<>,
333 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000334 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000335 &FalseFuncI32<>,
336 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100337 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
338 output.GetDataType(),
339 &TrueFunc<>,
340 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000341 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000342 &FalseFuncI32<>,
343 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100344}
345
346bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
347 const TensorInfo& output,
348 const Convolution2dDescriptor& descriptor,
349 const TensorInfo& weights,
350 const Optional<TensorInfo>& biases,
351 Optional<std::string&> reasonIfUnsupported) const
352{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100353 ignore_unused(output);
354 ignore_unused(descriptor);
355 ignore_unused(weights);
356 ignore_unused(biases);
357 return IsSupportedForDataTypeRef(reasonIfUnsupported,
358 input.GetDataType(),
359 &TrueFunc<>,
360 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100361}
362
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000363bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
364 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000365 Optional<std::string&> reasonIfUnsupported) const
366{
367 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000368 return IsSupportedForDataTypeRef(reasonIfUnsupported,
369 input.GetDataType(),
370 &TrueFunc<>,
371 &TrueFunc<>);
372}
373
arovir011c7c81b2018-10-08 11:34:28 +0100374bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
375 const TensorInfo& output,
376 const DepthwiseConvolution2dDescriptor& descriptor,
377 const TensorInfo& weights,
378 const Optional<TensorInfo>& biases,
379 Optional<std::string&> reasonIfUnsupported) const
380{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100381 ignore_unused(output);
382 ignore_unused(descriptor);
383 ignore_unused(weights);
384 ignore_unused(biases);
385 return IsSupportedForDataTypeRef(reasonIfUnsupported,
386 input.GetDataType(),
387 &TrueFunc<>,
388 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100389}
390
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000391bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
392 const TensorInfo& output,
393 Optional<std::string&> reasonIfUnsupported) const
394{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100395 bool supported = true;
396
397 std::array<DataType,2> supportedInputTypes = {
398 DataType::QuantisedAsymm8,
399 DataType::QuantisedSymm16
400 };
401
402 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
403 "Reference dequantize: input type not supported.");
404
405 std::array<DataType,2> supportedOutputTypes = {
406 DataType::Float32,
407 };
408
409 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
410 "Reference dequantize: output type not supported.");
411
412 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
413 "Reference dequantize: input and output shapes have different num total elements.");
414
415 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000416}
417
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000418bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
419 const armnn::TensorInfo& input1,
420 const armnn::DetectionPostProcessDescriptor& descriptor,
421 armnn::Optional<std::string&> reasonIfUnsupported) const
422{
423 ignore_unused(input1);
424 return IsSupportedForDataTypeRef(reasonIfUnsupported,
425 input0.GetDataType(),
426 &TrueFunc<>,
427 &TrueFunc<>);
428}
429
arovir011c7c81b2018-10-08 11:34:28 +0100430bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
431 const TensorInfo& input1,
432 const TensorInfo& output,
433 Optional<std::string&> reasonIfUnsupported) const
434{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100435 ignore_unused(input1);
436 ignore_unused(output);
437 return IsSupportedForDataTypeRef(reasonIfUnsupported,
438 input0.GetDataType(),
439 &TrueFunc<>,
440 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100441}
442
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000443bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
444 const TensorInfo& input1,
445 const TensorInfo& output,
446 Optional<std::string&> reasonIfUnsupported) const
447{
448 ignore_unused(input0);
449 ignore_unused(input1);
450 ignore_unused(output);
451 ignore_unused(reasonIfUnsupported);
452 return IsSupportedForDataTypeRef(reasonIfUnsupported,
453 input0.GetDataType(),
454 &TrueFunc<>,
455 &TrueFunc<>);
456}
457
arovir011c7c81b2018-10-08 11:34:28 +0100458bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
459 const FakeQuantizationDescriptor& descriptor,
460 Optional<std::string&> reasonIfUnsupported) const
461{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100462 ignore_unused(descriptor);
463 return IsSupportedForDataTypeRef(reasonIfUnsupported,
464 input.GetDataType(),
465 &TrueFunc<>,
466 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100467}
468
469bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
470 const TensorInfo& output,
471 Optional<std::string&> reasonIfUnsupported) const
472{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100473 ignore_unused(output);
474 return IsSupportedForDataTypeRef(reasonIfUnsupported,
475 input.GetDataType(),
476 &TrueFunc<>,
477 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100478}
479
480bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
481 const TensorInfo& output,
482 const TensorInfo& weights,
483 const TensorInfo& biases,
484 const FullyConnectedDescriptor& descriptor,
485 Optional<std::string&> reasonIfUnsupported) const
486{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100487 ignore_unused(output);
488 ignore_unused(weights);
489 ignore_unused(biases);
490 ignore_unused(descriptor);
491 return IsSupportedForDataTypeRef(reasonIfUnsupported,
492 input.GetDataType(),
493 &TrueFunc<>,
494 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100495}
496
narpra014951d842019-01-18 16:53:53 +0000497bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
498 const armnn::TensorInfo& input1,
499 const armnn::TensorInfo& output,
500 armnn::Optional<std::string&> reasonIfUnsupported) const
501{
502 ignore_unused(input1);
503 ignore_unused(output);
504 return IsSupportedForDataTypeRef(reasonIfUnsupported,
505 input0.GetDataType(),
506 &TrueFunc<>,
507 &TrueFunc<>);
508}
509
FrancisMurtagh878f0232018-12-19 10:56:15 +0000510bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
511 const TensorInfo& input1,
512 const TensorInfo& output,
513 Optional<std::string&> reasonIfUnsupported) const
514{
515 ignore_unused(input0);
516 ignore_unused(input1);
517 ignore_unused(output);
518 ignore_unused(reasonIfUnsupported);
519 return IsSupportedForDataTypeRef(reasonIfUnsupported,
520 input0.GetDataType(),
521 &TrueFunc<>,
522 &TrueFunc<>);
523}
524
arovir011c7c81b2018-10-08 11:34:28 +0100525bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
526 Optional<std::string&> reasonIfUnsupported) const
527{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100528 return IsSupportedForDataTypeRef(reasonIfUnsupported,
529 input.GetDataType(),
530 &TrueFunc<>,
531 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100532}
533
534bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
535 const TensorInfo& output,
536 const L2NormalizationDescriptor& descriptor,
537 Optional<std::string&> reasonIfUnsupported) const
538{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100539 ignore_unused(output);
540 ignore_unused(descriptor);
541 return IsSupportedForDataTypeRef(reasonIfUnsupported,
542 input.GetDataType(),
543 &TrueFunc<>,
544 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100545}
546
547bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
548 const TensorInfo& outputStateIn,
549 const TensorInfo& cellStateIn,
550 const TensorInfo& scratchBuffer,
551 const TensorInfo& outputStateOut,
552 const TensorInfo& cellStateOut,
553 const TensorInfo& output,
554 const LstmDescriptor& descriptor,
555 const TensorInfo& inputToForgetWeights,
556 const TensorInfo& inputToCellWeights,
557 const TensorInfo& inputToOutputWeights,
558 const TensorInfo& recurrentToForgetWeights,
559 const TensorInfo& recurrentToCellWeights,
560 const TensorInfo& recurrentToOutputWeights,
561 const TensorInfo& forgetGateBias,
562 const TensorInfo& cellBias,
563 const TensorInfo& outputGateBias,
564 const TensorInfo* inputToInputWeights,
565 const TensorInfo* recurrentToInputWeights,
566 const TensorInfo* cellToInputWeights,
567 const TensorInfo* inputGateBias,
568 const TensorInfo* projectionWeights,
569 const TensorInfo* projectionBias,
570 const TensorInfo* cellToForgetWeights,
571 const TensorInfo* cellToOutputWeights,
572 Optional<std::string&> reasonIfUnsupported) const
573{
telsoa01c577f2c2018-08-31 09:22:23 +0100574 ignore_unused(outputStateIn);
575 ignore_unused(cellStateIn);
576 ignore_unused(scratchBuffer);
577 ignore_unused(outputStateOut);
578 ignore_unused(cellStateOut);
579 ignore_unused(output);
580 ignore_unused(descriptor);
581 ignore_unused(inputToForgetWeights);
582 ignore_unused(inputToCellWeights);
583 ignore_unused(inputToOutputWeights);
584 ignore_unused(recurrentToForgetWeights);
585 ignore_unused(recurrentToCellWeights);
586 ignore_unused(recurrentToOutputWeights);
587 ignore_unused(forgetGateBias);
588 ignore_unused(cellBias);
589 ignore_unused(outputGateBias);
590 ignore_unused(inputToInputWeights);
591 ignore_unused(recurrentToInputWeights);
592 ignore_unused(cellToInputWeights);
593 ignore_unused(inputGateBias);
594 ignore_unused(projectionWeights);
595 ignore_unused(projectionBias);
596 ignore_unused(cellToForgetWeights);
597 ignore_unused(cellToOutputWeights);
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000598 return IsSupportedForDataTypeRef(reasonIfUnsupported,
599 input.GetDataType(),
600 &TrueFunc<>,
601 &FalseFuncU8<>);
telsoa01c577f2c2018-08-31 09:22:23 +0100602}
603
saoste012df12b32018-11-28 16:57:20 +0000604bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
605 const TensorInfo& input1,
606 const TensorInfo& output,
607 Optional<std::string&> reasonIfUnsupported) const
608{
609 ignore_unused(input1);
610 ignore_unused(output);
611 return IsSupportedForDataTypeRef(reasonIfUnsupported,
612 input0.GetDataType(),
613 &TrueFunc<>,
614 &TrueFunc<>);
615}
616
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100617bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
618 const TensorInfo& output,
619 const MeanDescriptor& descriptor,
620 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100621{
narpra011e4c31d2018-09-28 11:07:51 +0100622 ignore_unused(output);
623 ignore_unused(descriptor);
624 return IsSupportedForDataTypeRef(reasonIfUnsupported,
625 input.GetDataType(),
626 &TrueFunc<>,
627 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100628}
629
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100630bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000631 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100632 const OriginsDescriptor& descriptor,
633 Optional<std::string&> reasonIfUnsupported) const
634{
635 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000636 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100637 return IsSupportedForDataTypeRef(reasonIfUnsupported,
638 inputs[0]->GetDataType(),
639 &TrueFunc<>,
640 &TrueFunc<>);
641}
642
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000643bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
644 const TensorInfo &output,
645 Optional<std::string &> reasonIfUnsupported) const
646{
647 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000648 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
649 input.GetDataType(),
650 &TrueFunc<>,
651 &TrueFunc<>,
652 &TrueFunc<>,
653 &FalseFuncI32<>,
654 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000655}
656
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000657bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
658 const TensorInfo& input1,
659 const TensorInfo& output,
660 Optional<std::string&> reasonIfUnsupported) const
661{
662 ignore_unused(input1);
663 ignore_unused(output);
664 return IsSupportedForDataTypeRef(reasonIfUnsupported,
665 input0.GetDataType(),
666 &TrueFunc<>,
667 &TrueFunc<>);
668}
669
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100670bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
671 const TensorInfo& input1,
672 const TensorInfo& output,
673 Optional<std::string&> reasonIfUnsupported) const
674{
675 ignore_unused(input1);
676 ignore_unused(output);
677 return IsSupportedForDataTypeRef(reasonIfUnsupported,
678 input0.GetDataType(),
679 &TrueFunc<>,
680 &TrueFunc<>);
681}
682
683bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
684 const TensorInfo& output,
685 const NormalizationDescriptor& descriptor,
686 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100687{
688 ignore_unused(output);
689 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100690 return IsSupportedForDataTypeRef(reasonIfUnsupported,
691 input.GetDataType(),
692 &TrueFunc<>,
693 &FalseFuncU8<>);
694}
695
696bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
697 Optional<std::string&> reasonIfUnsupported) const
698{
kevmay012b4d88e2019-01-24 14:05:09 +0000699 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
700 output.GetDataType(),
701 &TrueFunc<>,
702 &TrueFunc<>,
703 &TrueFunc<>,
704 &FalseFuncI32<>,
705 &TrueFunc<>);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100706}
707
708bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
709 const TensorInfo& output,
710 const PadDescriptor& descriptor,
711 Optional<std::string&> reasonIfUnsupported) const
712{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100713 ignore_unused(output);
714 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000715 return IsSupportedForDataTypeRef(reasonIfUnsupported,
716 input.GetDataType(),
717 &TrueFunc<>,
718 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100719}
720
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100721bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
722 const TensorInfo& output,
723 const PermuteDescriptor& descriptor,
724 Optional<std::string&> reasonIfUnsupported) const
725{
726 ignore_unused(output);
727 ignore_unused(descriptor);
728 return IsSupportedForDataTypeRef(reasonIfUnsupported,
729 input.GetDataType(),
730 &TrueFunc<>,
731 &TrueFunc<>);
732}
733
734bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
735 const TensorInfo& output,
736 const Pooling2dDescriptor& descriptor,
737 Optional<std::string&> reasonIfUnsupported) const
738{
739 ignore_unused(output);
740 ignore_unused(descriptor);
741 return IsSupportedForDataTypeRef(reasonIfUnsupported,
742 input.GetDataType(),
743 &TrueFunc<>,
744 &TrueFunc<>);
745}
746
Derek Lamberti5f400d62019-03-25 15:41:58 +0000747bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
748 const TensorInfo& output,
749 Optional<std::string&> reasonIfUnsupported) const
750{
751 bool supported = true;
752
753 // Define supported output types.
754 std::array<DataType,2> supportedInputTypes = {
755 DataType::Float32,
756 };
757
758 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
759 "Reference quantize: input type not supported.");
760
761 // Define supported output types.
762 std::array<DataType,2> supportedOutputTypes = {
763 DataType::QuantisedAsymm8,
764 DataType::QuantisedSymm16
765 };
766 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
767 "Reference quantize: output type not supported.");
768
769 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
770 "Reference quantize: input and output shapes have different num total elements.");
771
772 return supported;
773}
774
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100775bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000776 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100777 Optional<std::string&> reasonIfUnsupported) const
778{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000779 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100780 return IsSupportedForDataTypeRef(reasonIfUnsupported,
781 input.GetDataType(),
782 &TrueFunc<>,
783 &TrueFunc<>);
784}
785
786bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000787 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100788 Optional<std::string&> reasonIfUnsupported) const
789{
Sadik Armaganc625f002018-12-17 11:32:16 +0000790 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100791 return IsSupportedForDataTypeRef(reasonIfUnsupported,
792 input.GetDataType(),
793 &TrueFunc<>,
794 &TrueFunc<>);
795}
796
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000797bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
798 const TensorInfo& output,
799 Optional<std::string&> reasonIfUnsupported) const
800{
801 ignore_unused(output);
802 return IsSupportedForDataTypeRef(reasonIfUnsupported,
803 input.GetDataType(),
804 &TrueFunc<>,
805 &FalseFuncU8<>);
806}
807
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100808bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
809 const TensorInfo& output,
810 const SoftmaxDescriptor& descriptor,
811 Optional<std::string&> reasonIfUnsupported) const
812{
813 ignore_unused(output);
814 ignore_unused(descriptor);
815 return IsSupportedForDataTypeRef(reasonIfUnsupported,
816 input.GetDataType(),
817 &TrueFunc<>,
818 &TrueFunc<>);
819}
820
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000821bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
822 const TensorInfo& output,
823 const SpaceToBatchNdDescriptor& descriptor,
824 Optional<std::string&> reasonIfUnsupported) const
825{
826 ignore_unused(output);
827 ignore_unused(descriptor);
828 return IsSupportedForDataTypeRef(reasonIfUnsupported,
829 input.GetDataType(),
830 &TrueFunc<>,
831 &TrueFunc<>);
832}
833
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100834bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
835 const ViewsDescriptor& descriptor,
836 Optional<std::string&> reasonIfUnsupported) const
837{
838 ignore_unused(descriptor);
839 return IsSupportedForDataTypeRef(reasonIfUnsupported,
840 input.GetDataType(),
841 &TrueFunc<>,
842 &TrueFunc<>);
843}
844
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000845bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
846 const TensorInfo& output,
847 const StridedSliceDescriptor& descriptor,
848 Optional<std::string&> reasonIfUnsupported) const
849{
850 ignore_unused(output);
851 ignore_unused(descriptor);
852 return IsSupportedForDataTypeRef(reasonIfUnsupported,
853 input.GetDataType(),
854 &TrueFunc<>,
855 &TrueFunc<>);
856}
857
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100858bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
859 const TensorInfo& input1,
860 const TensorInfo& output,
861 Optional<std::string&> reasonIfUnsupported) const
862{
863 ignore_unused(input1);
864 ignore_unused(output);
865 return IsSupportedForDataTypeRef(reasonIfUnsupported,
866 input0.GetDataType(),
867 &TrueFunc<>,
868 &TrueFunc<>);
869}
870
arovir011c7c81b2018-10-08 11:34:28 +0100871} // namespace armnn