//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>
#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>

#include <backendsCommon/BackendRegistry.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

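// Legacy helper used by most of the checks below: dispatches the support decision on the tensor
// data type, forwarding only the Float32 and QuantisedAsymm8 decisions to the supplied functors;
// every other data type handled by IsSupportedForDataTypeGeneric is rejected via FalseFunc.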
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace


namespace
{
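// Evaluates a support rule and, if it fails, appends the supplied reason to reasonIfUnsupported.
// Returns whether the rule passed, so callers can accumulate results with &=.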
template<typename F>
bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
{
    bool supported = rule();
    // Only append the reason if the caller actually supplied somewhere to put it.
    if (!supported && reason && reasonIfUnsupported.has_value())
    {
        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
    }
    return supported;
}

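// Base class for all support rules: a rule is a functor whose verdict is computed once in the
// constructor of the derived rule and returned by operator().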
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    bool m_Res = true;
};

template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}

struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};

struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};

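// Checks that a tensor's data type is one of the types in the given container.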
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
        {
            return dt == info.GetDataType();
        });
    }
};

struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};

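// Checks that two inputs can be implicitly broadcast to the output shape: dimensions are compared
// right-aligned, and each input dimension must either equal the corresponding output dimension or
// be 1. For example, inputs of shape {1,2,2,3} and {1,1,1,3} are compatible with an output of
// shape {1,2,2,3}.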
struct ShapesAreBroadcastCompatible : public Rule
{
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};
} // namespace


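// Rule-based check: every rule is evaluated so that each failure appends its own reason to
// reasonIfUnsupported, rather than stopping at the first unsupported condition.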
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and input 1 types are mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched.");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

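// The checks below have not been migrated to the rule-based style used above; they rely on the
// legacy data-type helpers and only validate tensor data types, not shapes or quantization
// parameters.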
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& var,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(mean);
    ignore_unused(var);
    ignore_unused(beta);
    ignore_unused(gamma);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return (IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      input.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>) &&
            IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      output.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>));
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &FalseFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFunc<>);
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input0);
    ignore_unused(input1);
    ignore_unused(output);
    ignore_unused(reasonIfUnsupported);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(weights);
    ignore_unused(biases);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input0);
    ignore_unused(input1);
    ignore_unused(output);
    ignore_unused(reasonIfUnsupported);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

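// LSTM is only supported for Float32 on the reference backend; the weight and state tensors are
// not inspected here.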
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(outputStateIn);
    ignore_unused(cellStateIn);
    ignore_unused(scratchBuffer);
    ignore_unused(outputStateOut);
    ignore_unused(cellStateOut);
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     inputs[0]->GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}

bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                                const TensorInfo& input1,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const NormalizationDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}

bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const Pooling2dDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                         const ReshapeDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const SoftmaxDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const SpaceToBatchNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const StridedSliceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                             const TensorInfo& input1,
                                             const TensorInfo& output,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

} // namespace armnn