blob: cdc6acae7ff783bdc82070ccb16707830ba6be8c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
namespace
{

// Reference-backend dispatch helper: forwards to IsSupportedForDataTypeGeneric,
// wiring the caller's predicates into the Float32 and QAsymm8 slots and
// FalseFunc (always unsupported) into every other slot — presumably Float16,
// Signed32 and Boolean per the argument order in LayerSupportCommon.hpp
// (TODO confirm against that header).
//
// @param reasonIfUnsupported  Optional destination for a failure message.
// @param dataType             Data type of the tensor being queried.
// @param floatFuncPtr         Predicate invoked when dataType is Float32.
// @param uint8FuncPtr         Predicate invoked when dataType is QAsymm8.
// @param params               Extra arguments forwarded to the chosen predicate.
// @return true if the selected predicate reports the type as supported.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
63struct Rule
64{
65 bool operator()() const
66 {
67 return m_Res;
68 }
69
70 bool m_Res = true;
71};
72
73template<class none = void>
74bool AllTypesAreEqualImpl()
75{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
84 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(rest...);
85}
86
87struct TypesAreEqual : public Rule
88{
89 template<typename ... Ts>
90 TypesAreEqual(const Ts&... ts)
91 {
92 m_Res = AllTypesAreEqualImpl(ts...);
93 }
94};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
125struct ShapesAreBroadcastCompatible : public Rule
126{
127 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
128 {
129 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
130 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
131 return sizeIn;
132 }
133
134 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
135 {
136 const TensorShape& shape0 = in0.GetShape();
137 const TensorShape& shape1 = in1.GetShape();
138 const TensorShape& outShape = out.GetShape();
139
140 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
141 {
142 unsigned int sizeOut = outShape[i];
143 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
144 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
145
146 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
147 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
148 }
149 }
150};
151} // namespace
152
153
// Reports whether the reference backend can execute an activation layer.
// Checks, accumulating one reason line per failed rule: input/output types are
// Float32 or QAsymm8, input and output types match, ranks match, and the
// requested activation function is one of the implemented set.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    // All rules are evaluated (no short-circuit) so every applicable failure
    // reason is reported, not just the first.
    bool supported = true;

    // Define supported types.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: true only for activation functions the reference
    // implementation provides.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
215
// Reports whether the reference backend can execute an addition layer.
// Checks, accumulating one reason line per failed rule: all three tensors are
// Float32 or QAsymm8, the types agree, and the input shapes are implicitly
// broadcastable to the output shape.
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    // All rules are evaluated (no short-circuit) so every applicable failure
    // reason is reported.
    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}
248
249bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
250 const TensorInfo& output,
251 const TensorInfo& mean,
252 const TensorInfo& var,
253 const TensorInfo& beta,
254 const TensorInfo& gamma,
255 const BatchNormalizationDescriptor& descriptor,
256 Optional<std::string&> reasonIfUnsupported) const
257{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100258 ignore_unused(output);
259 ignore_unused(mean);
260 ignore_unused(var);
261 ignore_unused(beta);
262 ignore_unused(gamma);
263 ignore_unused(descriptor);
264 return IsSupportedForDataTypeRef(reasonIfUnsupported,
265 input.GetDataType(),
266 &TrueFunc<>,
267 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100268}
269
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000270bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
271 const TensorInfo& output,
272 const BatchToSpaceNdDescriptor& descriptor,
273 Optional<std::string&> reasonIfUnsupported) const
274{
275 ignore_unused(descriptor);
276 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
277 input.GetDataType(),
278 &TrueFunc<>,
279 &TrueFunc<>) &&
280 IsSupportedForDataTypeRef(reasonIfUnsupported,
281 output.GetDataType(),
282 &TrueFunc<>,
283 &TrueFunc<>));
284}
285
arovir011c7c81b2018-10-08 11:34:28 +0100286bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
287 Optional<std::string&> reasonIfUnsupported) const
288{
narpra01db2b1602019-01-23 15:23:11 +0000289 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
290 output.GetDataType(),
291 &FalseFunc<>,
292 &TrueFunc<>,
293 &TrueFunc<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000294 &TrueFunc<>,
295 &FalseFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100296}
297
298bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
299 const TensorInfo& output,
300 Optional<std::string&> reasonIfUnsupported) const
301{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100302 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
303 input.GetDataType(),
304 &TrueFunc<>,
305 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000306 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000307 &FalseFuncI32<>,
308 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100309 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
310 output.GetDataType(),
311 &FalseOutputFuncF16<>,
312 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000313 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000314 &FalseFuncI32<>,
315 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100316}
317
318bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
319 const TensorInfo& output,
320 Optional<std::string&> reasonIfUnsupported) const
321{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100322 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
323 input.GetDataType(),
324 &FalseInputFuncF16<>,
325 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000326 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000327 &FalseFuncI32<>,
328 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100329 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
330 output.GetDataType(),
331 &TrueFunc<>,
332 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000333 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000334 &FalseFuncI32<>,
335 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100336}
337
338bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
339 const TensorInfo& output,
340 const Convolution2dDescriptor& descriptor,
341 const TensorInfo& weights,
342 const Optional<TensorInfo>& biases,
343 Optional<std::string&> reasonIfUnsupported) const
344{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100345 ignore_unused(output);
346 ignore_unused(descriptor);
347 ignore_unused(weights);
348 ignore_unused(biases);
349 return IsSupportedForDataTypeRef(reasonIfUnsupported,
350 input.GetDataType(),
351 &TrueFunc<>,
352 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100353}
354
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000355bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
356 const TensorInfo& output,
357 const DebugDescriptor& descriptor,
358 Optional<std::string&> reasonIfUnsupported) const
359{
360 ignore_unused(output);
361 ignore_unused(descriptor);
362 return IsSupportedForDataTypeRef(reasonIfUnsupported,
363 input.GetDataType(),
364 &TrueFunc<>,
365 &TrueFunc<>);
366}
367
arovir011c7c81b2018-10-08 11:34:28 +0100368bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
369 const TensorInfo& output,
370 const DepthwiseConvolution2dDescriptor& descriptor,
371 const TensorInfo& weights,
372 const Optional<TensorInfo>& biases,
373 Optional<std::string&> reasonIfUnsupported) const
374{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100375 ignore_unused(output);
376 ignore_unused(descriptor);
377 ignore_unused(weights);
378 ignore_unused(biases);
379 return IsSupportedForDataTypeRef(reasonIfUnsupported,
380 input.GetDataType(),
381 &TrueFunc<>,
382 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100383}
384
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000385bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
386 const armnn::TensorInfo& input1,
387 const armnn::DetectionPostProcessDescriptor& descriptor,
388 armnn::Optional<std::string&> reasonIfUnsupported) const
389{
390 ignore_unused(input1);
391 return IsSupportedForDataTypeRef(reasonIfUnsupported,
392 input0.GetDataType(),
393 &TrueFunc<>,
394 &TrueFunc<>);
395}
396
arovir011c7c81b2018-10-08 11:34:28 +0100397bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
398 const TensorInfo& input1,
399 const TensorInfo& output,
400 Optional<std::string&> reasonIfUnsupported) const
401{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100402 ignore_unused(input1);
403 ignore_unused(output);
404 return IsSupportedForDataTypeRef(reasonIfUnsupported,
405 input0.GetDataType(),
406 &TrueFunc<>,
407 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100408}
409
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000410bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
411 const TensorInfo& input1,
412 const TensorInfo& output,
413 Optional<std::string&> reasonIfUnsupported) const
414{
415 ignore_unused(input0);
416 ignore_unused(input1);
417 ignore_unused(output);
418 ignore_unused(reasonIfUnsupported);
419 return IsSupportedForDataTypeRef(reasonIfUnsupported,
420 input0.GetDataType(),
421 &TrueFunc<>,
422 &TrueFunc<>);
423}
424
arovir011c7c81b2018-10-08 11:34:28 +0100425bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
426 const FakeQuantizationDescriptor& descriptor,
427 Optional<std::string&> reasonIfUnsupported) const
428{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100429 ignore_unused(descriptor);
430 return IsSupportedForDataTypeRef(reasonIfUnsupported,
431 input.GetDataType(),
432 &TrueFunc<>,
433 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100434}
435
436bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
437 const TensorInfo& output,
438 Optional<std::string&> reasonIfUnsupported) const
439{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100440 ignore_unused(output);
441 return IsSupportedForDataTypeRef(reasonIfUnsupported,
442 input.GetDataType(),
443 &TrueFunc<>,
444 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100445}
446
447bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
448 const TensorInfo& output,
449 const TensorInfo& weights,
450 const TensorInfo& biases,
451 const FullyConnectedDescriptor& descriptor,
452 Optional<std::string&> reasonIfUnsupported) const
453{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100454 ignore_unused(output);
455 ignore_unused(weights);
456 ignore_unused(biases);
457 ignore_unused(descriptor);
458 return IsSupportedForDataTypeRef(reasonIfUnsupported,
459 input.GetDataType(),
460 &TrueFunc<>,
461 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100462}
463
narpra014951d842019-01-18 16:53:53 +0000464bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
465 const armnn::TensorInfo& input1,
466 const armnn::TensorInfo& output,
467 armnn::Optional<std::string&> reasonIfUnsupported) const
468{
469 ignore_unused(input1);
470 ignore_unused(output);
471 return IsSupportedForDataTypeRef(reasonIfUnsupported,
472 input0.GetDataType(),
473 &TrueFunc<>,
474 &TrueFunc<>);
475}
476
FrancisMurtagh878f0232018-12-19 10:56:15 +0000477bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
478 const TensorInfo& input1,
479 const TensorInfo& output,
480 Optional<std::string&> reasonIfUnsupported) const
481{
482 ignore_unused(input0);
483 ignore_unused(input1);
484 ignore_unused(output);
485 ignore_unused(reasonIfUnsupported);
486 return IsSupportedForDataTypeRef(reasonIfUnsupported,
487 input0.GetDataType(),
488 &TrueFunc<>,
489 &TrueFunc<>);
490}
491
arovir011c7c81b2018-10-08 11:34:28 +0100492bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
493 Optional<std::string&> reasonIfUnsupported) const
494{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100495 return IsSupportedForDataTypeRef(reasonIfUnsupported,
496 input.GetDataType(),
497 &TrueFunc<>,
498 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100499}
500
501bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
502 const TensorInfo& output,
503 const L2NormalizationDescriptor& descriptor,
504 Optional<std::string&> reasonIfUnsupported) const
505{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100506 ignore_unused(output);
507 ignore_unused(descriptor);
508 return IsSupportedForDataTypeRef(reasonIfUnsupported,
509 input.GetDataType(),
510 &TrueFunc<>,
511 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100512}
513
514bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
515 const TensorInfo& outputStateIn,
516 const TensorInfo& cellStateIn,
517 const TensorInfo& scratchBuffer,
518 const TensorInfo& outputStateOut,
519 const TensorInfo& cellStateOut,
520 const TensorInfo& output,
521 const LstmDescriptor& descriptor,
522 const TensorInfo& inputToForgetWeights,
523 const TensorInfo& inputToCellWeights,
524 const TensorInfo& inputToOutputWeights,
525 const TensorInfo& recurrentToForgetWeights,
526 const TensorInfo& recurrentToCellWeights,
527 const TensorInfo& recurrentToOutputWeights,
528 const TensorInfo& forgetGateBias,
529 const TensorInfo& cellBias,
530 const TensorInfo& outputGateBias,
531 const TensorInfo* inputToInputWeights,
532 const TensorInfo* recurrentToInputWeights,
533 const TensorInfo* cellToInputWeights,
534 const TensorInfo* inputGateBias,
535 const TensorInfo* projectionWeights,
536 const TensorInfo* projectionBias,
537 const TensorInfo* cellToForgetWeights,
538 const TensorInfo* cellToOutputWeights,
539 Optional<std::string&> reasonIfUnsupported) const
540{
telsoa01c577f2c2018-08-31 09:22:23 +0100541 ignore_unused(outputStateIn);
542 ignore_unused(cellStateIn);
543 ignore_unused(scratchBuffer);
544 ignore_unused(outputStateOut);
545 ignore_unused(cellStateOut);
546 ignore_unused(output);
547 ignore_unused(descriptor);
548 ignore_unused(inputToForgetWeights);
549 ignore_unused(inputToCellWeights);
550 ignore_unused(inputToOutputWeights);
551 ignore_unused(recurrentToForgetWeights);
552 ignore_unused(recurrentToCellWeights);
553 ignore_unused(recurrentToOutputWeights);
554 ignore_unused(forgetGateBias);
555 ignore_unused(cellBias);
556 ignore_unused(outputGateBias);
557 ignore_unused(inputToInputWeights);
558 ignore_unused(recurrentToInputWeights);
559 ignore_unused(cellToInputWeights);
560 ignore_unused(inputGateBias);
561 ignore_unused(projectionWeights);
562 ignore_unused(projectionBias);
563 ignore_unused(cellToForgetWeights);
564 ignore_unused(cellToOutputWeights);
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000565 return IsSupportedForDataTypeRef(reasonIfUnsupported,
566 input.GetDataType(),
567 &TrueFunc<>,
568 &FalseFuncU8<>);
telsoa01c577f2c2018-08-31 09:22:23 +0100569}
570
saoste012df12b32018-11-28 16:57:20 +0000571bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
572 const TensorInfo& input1,
573 const TensorInfo& output,
574 Optional<std::string&> reasonIfUnsupported) const
575{
576 ignore_unused(input1);
577 ignore_unused(output);
578 return IsSupportedForDataTypeRef(reasonIfUnsupported,
579 input0.GetDataType(),
580 &TrueFunc<>,
581 &TrueFunc<>);
582}
583
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100584bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
585 const TensorInfo& output,
586 const MeanDescriptor& descriptor,
587 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100588{
narpra011e4c31d2018-09-28 11:07:51 +0100589 ignore_unused(output);
590 ignore_unused(descriptor);
591 return IsSupportedForDataTypeRef(reasonIfUnsupported,
592 input.GetDataType(),
593 &TrueFunc<>,
594 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100595}
596
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100597bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000598 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100599 const OriginsDescriptor& descriptor,
600 Optional<std::string&> reasonIfUnsupported) const
601{
602 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000603 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100604 return IsSupportedForDataTypeRef(reasonIfUnsupported,
605 inputs[0]->GetDataType(),
606 &TrueFunc<>,
607 &TrueFunc<>);
608}
609
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000610bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
611 const TensorInfo &output,
612 Optional<std::string &> reasonIfUnsupported) const
613{
614 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000615 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
616 input.GetDataType(),
617 &TrueFunc<>,
618 &TrueFunc<>,
619 &TrueFunc<>,
620 &FalseFuncI32<>,
621 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000622}
623
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000624bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
625 const TensorInfo& input1,
626 const TensorInfo& output,
627 Optional<std::string&> reasonIfUnsupported) const
628{
629 ignore_unused(input1);
630 ignore_unused(output);
631 return IsSupportedForDataTypeRef(reasonIfUnsupported,
632 input0.GetDataType(),
633 &TrueFunc<>,
634 &TrueFunc<>);
635}
636
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100637bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
638 const TensorInfo& input1,
639 const TensorInfo& output,
640 Optional<std::string&> reasonIfUnsupported) const
641{
642 ignore_unused(input1);
643 ignore_unused(output);
644 return IsSupportedForDataTypeRef(reasonIfUnsupported,
645 input0.GetDataType(),
646 &TrueFunc<>,
647 &TrueFunc<>);
648}
649
650bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
651 const TensorInfo& output,
652 const NormalizationDescriptor& descriptor,
653 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100654{
655 ignore_unused(output);
656 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100657 return IsSupportedForDataTypeRef(reasonIfUnsupported,
658 input.GetDataType(),
659 &TrueFunc<>,
660 &FalseFuncU8<>);
661}
662
663bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
664 Optional<std::string&> reasonIfUnsupported) const
665{
kevmay012b4d88e2019-01-24 14:05:09 +0000666 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
667 output.GetDataType(),
668 &TrueFunc<>,
669 &TrueFunc<>,
670 &TrueFunc<>,
671 &FalseFuncI32<>,
672 &TrueFunc<>);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100673}
674
675bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
676 const TensorInfo& output,
677 const PadDescriptor& descriptor,
678 Optional<std::string&> reasonIfUnsupported) const
679{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100680 ignore_unused(output);
681 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000682 return IsSupportedForDataTypeRef(reasonIfUnsupported,
683 input.GetDataType(),
684 &TrueFunc<>,
685 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100686}
687
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100688bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
689 const TensorInfo& output,
690 const PermuteDescriptor& descriptor,
691 Optional<std::string&> reasonIfUnsupported) const
692{
693 ignore_unused(output);
694 ignore_unused(descriptor);
695 return IsSupportedForDataTypeRef(reasonIfUnsupported,
696 input.GetDataType(),
697 &TrueFunc<>,
698 &TrueFunc<>);
699}
700
701bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
702 const TensorInfo& output,
703 const Pooling2dDescriptor& descriptor,
704 Optional<std::string&> reasonIfUnsupported) const
705{
706 ignore_unused(output);
707 ignore_unused(descriptor);
708 return IsSupportedForDataTypeRef(reasonIfUnsupported,
709 input.GetDataType(),
710 &TrueFunc<>,
711 &TrueFunc<>);
712}
713
714bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000715 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100716 Optional<std::string&> reasonIfUnsupported) const
717{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000718 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100719 return IsSupportedForDataTypeRef(reasonIfUnsupported,
720 input.GetDataType(),
721 &TrueFunc<>,
722 &TrueFunc<>);
723}
724
725bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000726 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100727 Optional<std::string&> reasonIfUnsupported) const
728{
Sadik Armaganc625f002018-12-17 11:32:16 +0000729 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100730 return IsSupportedForDataTypeRef(reasonIfUnsupported,
731 input.GetDataType(),
732 &TrueFunc<>,
733 &TrueFunc<>);
734}
735
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000736bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
737 const TensorInfo& output,
738 Optional<std::string&> reasonIfUnsupported) const
739{
740 ignore_unused(output);
741 return IsSupportedForDataTypeRef(reasonIfUnsupported,
742 input.GetDataType(),
743 &TrueFunc<>,
744 &FalseFuncU8<>);
745}
746
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100747bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
748 const TensorInfo& output,
749 const SoftmaxDescriptor& descriptor,
750 Optional<std::string&> reasonIfUnsupported) const
751{
752 ignore_unused(output);
753 ignore_unused(descriptor);
754 return IsSupportedForDataTypeRef(reasonIfUnsupported,
755 input.GetDataType(),
756 &TrueFunc<>,
757 &TrueFunc<>);
758}
759
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000760bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
761 const TensorInfo& output,
762 const SpaceToBatchNdDescriptor& descriptor,
763 Optional<std::string&> reasonIfUnsupported) const
764{
765 ignore_unused(output);
766 ignore_unused(descriptor);
767 return IsSupportedForDataTypeRef(reasonIfUnsupported,
768 input.GetDataType(),
769 &TrueFunc<>,
770 &TrueFunc<>);
771}
772
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100773bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
774 const ViewsDescriptor& descriptor,
775 Optional<std::string&> reasonIfUnsupported) const
776{
777 ignore_unused(descriptor);
778 return IsSupportedForDataTypeRef(reasonIfUnsupported,
779 input.GetDataType(),
780 &TrueFunc<>,
781 &TrueFunc<>);
782}
783
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000784bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
785 const TensorInfo& output,
786 const StridedSliceDescriptor& descriptor,
787 Optional<std::string&> reasonIfUnsupported) const
788{
789 ignore_unused(output);
790 ignore_unused(descriptor);
791 return IsSupportedForDataTypeRef(reasonIfUnsupported,
792 input.GetDataType(),
793 &TrueFunc<>,
794 &TrueFunc<>);
795}
796
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100797bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
798 const TensorInfo& input1,
799 const TensorInfo& output,
800 Optional<std::string&> reasonIfUnsupported) const
801{
802 ignore_unused(input1);
803 ignore_unused(output);
804 return IsSupportedForDataTypeRef(reasonIfUnsupported,
805 input0.GetDataType(),
806 &TrueFunc<>,
807 &TrueFunc<>);
808}
809
arovir011c7c81b2018-10-08 11:34:28 +0100810} // namespace armnn