blob: 532c8eaf987172577bdc711c2488d4eff1c2912b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010015
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <vector>
19#include <algorithm>
20#include <array>
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace boost;
23
24namespace armnn
25{
26
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010027namespace
28{
29
30template<typename Float32Func, typename Uint8Func, typename ... Params>
31bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
32 DataType dataType,
33 Float32Func floatFuncPtr,
34 Uint8Func uint8FuncPtr,
35 Params&&... params)
36{
37 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
38 dataType,
39 &FalseFunc<Params...>,
40 floatFuncPtr,
41 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000042 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000043 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 std::forward<Params>(params)...);
45}
46
47} // anonymous namespace
48
Derek Lamberti50db4e82019-03-13 14:16:15 +000049
50namespace
51{
52template<typename F>
53bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
54{
55 bool supported = rule();
56 if (!supported && reason)
57 {
58 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
59 }
60 return supported;
61}
62
63struct Rule
64{
65 bool operator()() const
66 {
67 return m_Res;
68 }
69
70 bool m_Res = true;
71};
72
Derek Lamberti2a434a82019-03-20 13:07:57 +000073template<typename T>
74bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000075{
76 return true;
77}
78
79template<typename T, typename... Rest>
80bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
81{
82 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
83
Derek Lamberti2a434a82019-03-20 13:07:57 +000084 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000085}
86
87struct TypesAreEqual : public Rule
88{
89 template<typename ... Ts>
90 TypesAreEqual(const Ts&... ts)
91 {
92 m_Res = AllTypesAreEqualImpl(ts...);
93 }
94};
95
96struct QuantizationParametersAreEqual : public Rule
97{
98 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
99 {
100 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
101 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
102 }
103};
104
105struct TypeAnyOf : public Rule
106{
107 template<typename Container>
108 TypeAnyOf(const TensorInfo& info, const Container& c)
109 {
110 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
111 {
112 return dt == info.GetDataType();
113 });
114 }
115};
116
117struct ShapesAreSameRank : public Rule
118{
119 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
120 {
121 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
122 }
123};
124
125struct ShapesAreBroadcastCompatible : public Rule
126{
127 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
128 {
129 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
130 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
131 return sizeIn;
132 }
133
134 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
135 {
136 const TensorShape& shape0 = in0.GetShape();
137 const TensorShape& shape1 = in1.GetShape();
138 const TensorShape& outShape = out.GetShape();
139
140 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
141 {
142 unsigned int sizeOut = outShape[i];
143 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
144 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
145
146 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
147 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
148 }
149 }
150};
151} // namespace
152
153
arovir011c7c81b2018-10-08 11:34:28 +0100154bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
155 const TensorInfo& output,
156 const ActivationDescriptor& descriptor,
157 Optional<std::string&> reasonIfUnsupported) const
158{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000159 bool supported = true;
160
161 // Define supported types.
162 std::array<DataType,2> supportedTypes = {
163 DataType::Float32,
164 DataType::QuantisedAsymm8
165 };
166
167 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
168 "Reference activation: input type not supported.");
169
170 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
171 "Reference activation: output type not supported.");
172
173 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
174 "Reference activation: input and output types mismatched.");
175
176 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
177 "Reference activation: input and output shapes are of different rank.");
178
179
180 struct ActivationFunctionSupported : public Rule
181 {
182 ActivationFunctionSupported(const ActivationDescriptor& desc)
183 {
184 switch(desc.m_Function)
185 {
186 case ActivationFunction::Abs:
187 case ActivationFunction::BoundedReLu:
188 case ActivationFunction::LeakyReLu:
189 case ActivationFunction::Linear:
190 case ActivationFunction::ReLu:
191 case ActivationFunction::Sigmoid:
192 case ActivationFunction::SoftReLu:
193 case ActivationFunction::Sqrt:
194 case ActivationFunction::Square:
195 case ActivationFunction::TanH:
196 {
197 m_Res = true;
198 break;
199 }
200 default:
201 {
202 m_Res = false;
203 break;
204 }
205 }
206 }
207 };
208
209 // Function is supported
210 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
211 "Reference activation: function not supported.");
212
213 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100214}
215
216bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
217 const TensorInfo& input1,
218 const TensorInfo& output,
219 Optional<std::string&> reasonIfUnsupported) const
220{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000221 bool supported = true;
222
223 std::array<DataType,2> supportedTypes = {
224 DataType::Float32,
225 DataType::QuantisedAsymm8
226 };
227
228 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
229 "Reference addition: input 0 is not a supported type.");
230
231 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
232 "Reference addition: input 1 is not a supported type.");
233
234 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
235 "Reference addition: output is not a supported type.");
236
237 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
238 "Reference addition: input 0 and Input 1 types are mismatched");
239
240 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
241 "Reference addition: input and output types are mismatched");
242
243 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
244 "Reference addition: shapes are not suitable for implicit broadcast.");
245
246 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100247}
248
249bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
250 const TensorInfo& output,
251 const TensorInfo& mean,
252 const TensorInfo& var,
253 const TensorInfo& beta,
254 const TensorInfo& gamma,
255 const BatchNormalizationDescriptor& descriptor,
256 Optional<std::string&> reasonIfUnsupported) const
257{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100258 ignore_unused(output);
259 ignore_unused(mean);
260 ignore_unused(var);
261 ignore_unused(beta);
262 ignore_unused(gamma);
263 ignore_unused(descriptor);
264 return IsSupportedForDataTypeRef(reasonIfUnsupported,
265 input.GetDataType(),
266 &TrueFunc<>,
267 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100268}
269
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000270bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
271 const TensorInfo& output,
272 const BatchToSpaceNdDescriptor& descriptor,
273 Optional<std::string&> reasonIfUnsupported) const
274{
275 ignore_unused(descriptor);
276 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
277 input.GetDataType(),
278 &TrueFunc<>,
279 &TrueFunc<>) &&
280 IsSupportedForDataTypeRef(reasonIfUnsupported,
281 output.GetDataType(),
282 &TrueFunc<>,
283 &TrueFunc<>));
284}
285
arovir011c7c81b2018-10-08 11:34:28 +0100286bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
287 Optional<std::string&> reasonIfUnsupported) const
288{
narpra01db2b1602019-01-23 15:23:11 +0000289 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
290 output.GetDataType(),
291 &FalseFunc<>,
292 &TrueFunc<>,
293 &TrueFunc<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000294 &TrueFunc<>,
295 &FalseFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100296}
297
298bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
299 const TensorInfo& output,
300 Optional<std::string&> reasonIfUnsupported) const
301{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100302 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
303 input.GetDataType(),
304 &TrueFunc<>,
305 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000306 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000307 &FalseFuncI32<>,
308 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100309 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
310 output.GetDataType(),
311 &FalseOutputFuncF16<>,
312 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000313 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000314 &FalseFuncI32<>,
315 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100316}
317
318bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
319 const TensorInfo& output,
320 Optional<std::string&> reasonIfUnsupported) const
321{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100322 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
323 input.GetDataType(),
324 &FalseInputFuncF16<>,
325 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000326 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000327 &FalseFuncI32<>,
328 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100329 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
330 output.GetDataType(),
331 &TrueFunc<>,
332 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000333 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000334 &FalseFuncI32<>,
335 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100336}
337
338bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
339 const TensorInfo& output,
340 const Convolution2dDescriptor& descriptor,
341 const TensorInfo& weights,
342 const Optional<TensorInfo>& biases,
343 Optional<std::string&> reasonIfUnsupported) const
344{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100345 ignore_unused(output);
346 ignore_unused(descriptor);
347 ignore_unused(weights);
348 ignore_unused(biases);
349 return IsSupportedForDataTypeRef(reasonIfUnsupported,
350 input.GetDataType(),
351 &TrueFunc<>,
352 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100353}
354
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000355bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
356 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000357 Optional<std::string&> reasonIfUnsupported) const
358{
359 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000360 return IsSupportedForDataTypeRef(reasonIfUnsupported,
361 input.GetDataType(),
362 &TrueFunc<>,
363 &TrueFunc<>);
364}
365
arovir011c7c81b2018-10-08 11:34:28 +0100366bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
367 const TensorInfo& output,
368 const DepthwiseConvolution2dDescriptor& descriptor,
369 const TensorInfo& weights,
370 const Optional<TensorInfo>& biases,
371 Optional<std::string&> reasonIfUnsupported) const
372{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100373 ignore_unused(output);
374 ignore_unused(descriptor);
375 ignore_unused(weights);
376 ignore_unused(biases);
377 return IsSupportedForDataTypeRef(reasonIfUnsupported,
378 input.GetDataType(),
379 &TrueFunc<>,
380 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100381}
382
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000383bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
384 const TensorInfo& output,
385 Optional<std::string&> reasonIfUnsupported) const
386{
387 return IsSupportedForDataTypeRef(reasonIfUnsupported,
388 input.GetDataType(),
389 &FalseFunc<>,
390 &TrueFunc<>);
391}
392
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000393bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
394 const armnn::TensorInfo& input1,
395 const armnn::DetectionPostProcessDescriptor& descriptor,
396 armnn::Optional<std::string&> reasonIfUnsupported) const
397{
398 ignore_unused(input1);
399 return IsSupportedForDataTypeRef(reasonIfUnsupported,
400 input0.GetDataType(),
401 &TrueFunc<>,
402 &TrueFunc<>);
403}
404
arovir011c7c81b2018-10-08 11:34:28 +0100405bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
406 const TensorInfo& input1,
407 const TensorInfo& output,
408 Optional<std::string&> reasonIfUnsupported) const
409{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100410 ignore_unused(input1);
411 ignore_unused(output);
412 return IsSupportedForDataTypeRef(reasonIfUnsupported,
413 input0.GetDataType(),
414 &TrueFunc<>,
415 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100416}
417
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000418bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
419 const TensorInfo& input1,
420 const TensorInfo& output,
421 Optional<std::string&> reasonIfUnsupported) const
422{
423 ignore_unused(input0);
424 ignore_unused(input1);
425 ignore_unused(output);
426 ignore_unused(reasonIfUnsupported);
427 return IsSupportedForDataTypeRef(reasonIfUnsupported,
428 input0.GetDataType(),
429 &TrueFunc<>,
430 &TrueFunc<>);
431}
432
arovir011c7c81b2018-10-08 11:34:28 +0100433bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
434 const FakeQuantizationDescriptor& descriptor,
435 Optional<std::string&> reasonIfUnsupported) const
436{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100437 ignore_unused(descriptor);
438 return IsSupportedForDataTypeRef(reasonIfUnsupported,
439 input.GetDataType(),
440 &TrueFunc<>,
441 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100442}
443
444bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
445 const TensorInfo& output,
446 Optional<std::string&> reasonIfUnsupported) const
447{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100448 ignore_unused(output);
449 return IsSupportedForDataTypeRef(reasonIfUnsupported,
450 input.GetDataType(),
451 &TrueFunc<>,
452 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100453}
454
455bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
456 const TensorInfo& output,
457 const TensorInfo& weights,
458 const TensorInfo& biases,
459 const FullyConnectedDescriptor& descriptor,
460 Optional<std::string&> reasonIfUnsupported) const
461{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100462 ignore_unused(output);
463 ignore_unused(weights);
464 ignore_unused(biases);
465 ignore_unused(descriptor);
466 return IsSupportedForDataTypeRef(reasonIfUnsupported,
467 input.GetDataType(),
468 &TrueFunc<>,
469 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100470}
471
narpra014951d842019-01-18 16:53:53 +0000472bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
473 const armnn::TensorInfo& input1,
474 const armnn::TensorInfo& output,
475 armnn::Optional<std::string&> reasonIfUnsupported) const
476{
477 ignore_unused(input1);
478 ignore_unused(output);
479 return IsSupportedForDataTypeRef(reasonIfUnsupported,
480 input0.GetDataType(),
481 &TrueFunc<>,
482 &TrueFunc<>);
483}
484
FrancisMurtagh878f0232018-12-19 10:56:15 +0000485bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
486 const TensorInfo& input1,
487 const TensorInfo& output,
488 Optional<std::string&> reasonIfUnsupported) const
489{
490 ignore_unused(input0);
491 ignore_unused(input1);
492 ignore_unused(output);
493 ignore_unused(reasonIfUnsupported);
494 return IsSupportedForDataTypeRef(reasonIfUnsupported,
495 input0.GetDataType(),
496 &TrueFunc<>,
497 &TrueFunc<>);
498}
499
arovir011c7c81b2018-10-08 11:34:28 +0100500bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
501 Optional<std::string&> reasonIfUnsupported) const
502{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100503 return IsSupportedForDataTypeRef(reasonIfUnsupported,
504 input.GetDataType(),
505 &TrueFunc<>,
506 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100507}
508
509bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
510 const TensorInfo& output,
511 const L2NormalizationDescriptor& descriptor,
512 Optional<std::string&> reasonIfUnsupported) const
513{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100514 ignore_unused(output);
515 ignore_unused(descriptor);
516 return IsSupportedForDataTypeRef(reasonIfUnsupported,
517 input.GetDataType(),
518 &TrueFunc<>,
519 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100520}
521
522bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
523 const TensorInfo& outputStateIn,
524 const TensorInfo& cellStateIn,
525 const TensorInfo& scratchBuffer,
526 const TensorInfo& outputStateOut,
527 const TensorInfo& cellStateOut,
528 const TensorInfo& output,
529 const LstmDescriptor& descriptor,
530 const TensorInfo& inputToForgetWeights,
531 const TensorInfo& inputToCellWeights,
532 const TensorInfo& inputToOutputWeights,
533 const TensorInfo& recurrentToForgetWeights,
534 const TensorInfo& recurrentToCellWeights,
535 const TensorInfo& recurrentToOutputWeights,
536 const TensorInfo& forgetGateBias,
537 const TensorInfo& cellBias,
538 const TensorInfo& outputGateBias,
539 const TensorInfo* inputToInputWeights,
540 const TensorInfo* recurrentToInputWeights,
541 const TensorInfo* cellToInputWeights,
542 const TensorInfo* inputGateBias,
543 const TensorInfo* projectionWeights,
544 const TensorInfo* projectionBias,
545 const TensorInfo* cellToForgetWeights,
546 const TensorInfo* cellToOutputWeights,
547 Optional<std::string&> reasonIfUnsupported) const
548{
telsoa01c577f2c2018-08-31 09:22:23 +0100549 ignore_unused(outputStateIn);
550 ignore_unused(cellStateIn);
551 ignore_unused(scratchBuffer);
552 ignore_unused(outputStateOut);
553 ignore_unused(cellStateOut);
554 ignore_unused(output);
555 ignore_unused(descriptor);
556 ignore_unused(inputToForgetWeights);
557 ignore_unused(inputToCellWeights);
558 ignore_unused(inputToOutputWeights);
559 ignore_unused(recurrentToForgetWeights);
560 ignore_unused(recurrentToCellWeights);
561 ignore_unused(recurrentToOutputWeights);
562 ignore_unused(forgetGateBias);
563 ignore_unused(cellBias);
564 ignore_unused(outputGateBias);
565 ignore_unused(inputToInputWeights);
566 ignore_unused(recurrentToInputWeights);
567 ignore_unused(cellToInputWeights);
568 ignore_unused(inputGateBias);
569 ignore_unused(projectionWeights);
570 ignore_unused(projectionBias);
571 ignore_unused(cellToForgetWeights);
572 ignore_unused(cellToOutputWeights);
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000573 return IsSupportedForDataTypeRef(reasonIfUnsupported,
574 input.GetDataType(),
575 &TrueFunc<>,
576 &FalseFuncU8<>);
telsoa01c577f2c2018-08-31 09:22:23 +0100577}
578
saoste012df12b32018-11-28 16:57:20 +0000579bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
580 const TensorInfo& input1,
581 const TensorInfo& output,
582 Optional<std::string&> reasonIfUnsupported) const
583{
584 ignore_unused(input1);
585 ignore_unused(output);
586 return IsSupportedForDataTypeRef(reasonIfUnsupported,
587 input0.GetDataType(),
588 &TrueFunc<>,
589 &TrueFunc<>);
590}
591
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100592bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
593 const TensorInfo& output,
594 const MeanDescriptor& descriptor,
595 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100596{
narpra011e4c31d2018-09-28 11:07:51 +0100597 ignore_unused(output);
598 ignore_unused(descriptor);
599 return IsSupportedForDataTypeRef(reasonIfUnsupported,
600 input.GetDataType(),
601 &TrueFunc<>,
602 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100603}
604
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100605bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000606 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100607 const OriginsDescriptor& descriptor,
608 Optional<std::string&> reasonIfUnsupported) const
609{
610 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000611 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100612 return IsSupportedForDataTypeRef(reasonIfUnsupported,
613 inputs[0]->GetDataType(),
614 &TrueFunc<>,
615 &TrueFunc<>);
616}
617
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000618bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
619 const TensorInfo &output,
620 Optional<std::string &> reasonIfUnsupported) const
621{
622 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000623 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
624 input.GetDataType(),
625 &TrueFunc<>,
626 &TrueFunc<>,
627 &TrueFunc<>,
628 &FalseFuncI32<>,
629 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000630}
631
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000632bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
633 const TensorInfo& input1,
634 const TensorInfo& output,
635 Optional<std::string&> reasonIfUnsupported) const
636{
637 ignore_unused(input1);
638 ignore_unused(output);
639 return IsSupportedForDataTypeRef(reasonIfUnsupported,
640 input0.GetDataType(),
641 &TrueFunc<>,
642 &TrueFunc<>);
643}
644
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100645bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
646 const TensorInfo& input1,
647 const TensorInfo& output,
648 Optional<std::string&> reasonIfUnsupported) const
649{
650 ignore_unused(input1);
651 ignore_unused(output);
652 return IsSupportedForDataTypeRef(reasonIfUnsupported,
653 input0.GetDataType(),
654 &TrueFunc<>,
655 &TrueFunc<>);
656}
657
658bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
659 const TensorInfo& output,
660 const NormalizationDescriptor& descriptor,
661 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100662{
663 ignore_unused(output);
664 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100665 return IsSupportedForDataTypeRef(reasonIfUnsupported,
666 input.GetDataType(),
667 &TrueFunc<>,
668 &FalseFuncU8<>);
669}
670
671bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
672 Optional<std::string&> reasonIfUnsupported) const
673{
kevmay012b4d88e2019-01-24 14:05:09 +0000674 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
675 output.GetDataType(),
676 &TrueFunc<>,
677 &TrueFunc<>,
678 &TrueFunc<>,
679 &FalseFuncI32<>,
680 &TrueFunc<>);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100681}
682
683bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
684 const TensorInfo& output,
685 const PadDescriptor& descriptor,
686 Optional<std::string&> reasonIfUnsupported) const
687{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100688 ignore_unused(output);
689 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000690 return IsSupportedForDataTypeRef(reasonIfUnsupported,
691 input.GetDataType(),
692 &TrueFunc<>,
693 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100694}
695
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100696bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
697 const TensorInfo& output,
698 const PermuteDescriptor& descriptor,
699 Optional<std::string&> reasonIfUnsupported) const
700{
701 ignore_unused(output);
702 ignore_unused(descriptor);
703 return IsSupportedForDataTypeRef(reasonIfUnsupported,
704 input.GetDataType(),
705 &TrueFunc<>,
706 &TrueFunc<>);
707}
708
709bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
710 const TensorInfo& output,
711 const Pooling2dDescriptor& descriptor,
712 Optional<std::string&> reasonIfUnsupported) const
713{
714 ignore_unused(output);
715 ignore_unused(descriptor);
716 return IsSupportedForDataTypeRef(reasonIfUnsupported,
717 input.GetDataType(),
718 &TrueFunc<>,
719 &TrueFunc<>);
720}
721
722bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000723 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100724 Optional<std::string&> reasonIfUnsupported) const
725{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000726 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100727 return IsSupportedForDataTypeRef(reasonIfUnsupported,
728 input.GetDataType(),
729 &TrueFunc<>,
730 &TrueFunc<>);
731}
732
733bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +0000734 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100735 Optional<std::string&> reasonIfUnsupported) const
736{
Sadik Armaganc625f002018-12-17 11:32:16 +0000737 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100738 return IsSupportedForDataTypeRef(reasonIfUnsupported,
739 input.GetDataType(),
740 &TrueFunc<>,
741 &TrueFunc<>);
742}
743
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000744bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
745 const TensorInfo& output,
746 Optional<std::string&> reasonIfUnsupported) const
747{
748 ignore_unused(output);
749 return IsSupportedForDataTypeRef(reasonIfUnsupported,
750 input.GetDataType(),
751 &TrueFunc<>,
752 &FalseFuncU8<>);
753}
754
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100755bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
756 const TensorInfo& output,
757 const SoftmaxDescriptor& descriptor,
758 Optional<std::string&> reasonIfUnsupported) const
759{
760 ignore_unused(output);
761 ignore_unused(descriptor);
762 return IsSupportedForDataTypeRef(reasonIfUnsupported,
763 input.GetDataType(),
764 &TrueFunc<>,
765 &TrueFunc<>);
766}
767
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000768bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
769 const TensorInfo& output,
770 const SpaceToBatchNdDescriptor& descriptor,
771 Optional<std::string&> reasonIfUnsupported) const
772{
773 ignore_unused(output);
774 ignore_unused(descriptor);
775 return IsSupportedForDataTypeRef(reasonIfUnsupported,
776 input.GetDataType(),
777 &TrueFunc<>,
778 &TrueFunc<>);
779}
780
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100781bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
782 const ViewsDescriptor& descriptor,
783 Optional<std::string&> reasonIfUnsupported) const
784{
785 ignore_unused(descriptor);
786 return IsSupportedForDataTypeRef(reasonIfUnsupported,
787 input.GetDataType(),
788 &TrueFunc<>,
789 &TrueFunc<>);
790}
791
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000792bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
793 const TensorInfo& output,
794 const StridedSliceDescriptor& descriptor,
795 Optional<std::string&> reasonIfUnsupported) const
796{
797 ignore_unused(output);
798 ignore_unused(descriptor);
799 return IsSupportedForDataTypeRef(reasonIfUnsupported,
800 input.GetDataType(),
801 &TrueFunc<>,
802 &TrueFunc<>);
803}
804
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100805bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
806 const TensorInfo& input1,
807 const TensorInfo& output,
808 Optional<std::string&> reasonIfUnsupported) const
809{
810 ignore_unused(input1);
811 ignore_unused(output);
812 return IsSupportedForDataTypeRef(reasonIfUnsupported,
813 input0.GetDataType(),
814 &TrueFunc<>,
815 &TrueFunc<>);
816}
817
arovir011c7c81b2018-10-08 11:34:28 +0100818} // namespace armnn