blob: 2c8f9cb6e131397cdb185c6d6f71fa46c269309a [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
David Beck111b5d92018-11-12 14:59:37 +000013#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010014
telsoa014fcda012018-03-09 14:13:49 +000015#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
17using namespace boost;
18
19namespace armnn
20{
21
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010022namespace
23{
24
25template<typename Float32Func, typename Uint8Func, typename ... Params>
26bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
27 DataType dataType,
28 Float32Func floatFuncPtr,
29 Uint8Func uint8FuncPtr,
30 Params&&... params)
31{
32 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
33 dataType,
34 &FalseFunc<Params...>,
35 floatFuncPtr,
36 uint8FuncPtr,
37 std::forward<Params>(params)...);
38}
39
40} // anonymous namespace
41
arovir011c7c81b2018-10-08 11:34:28 +010042bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
43 const TensorInfo& output,
44 const ActivationDescriptor& descriptor,
45 Optional<std::string&> reasonIfUnsupported) const
46{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010047 ignore_unused(output);
48 ignore_unused(descriptor);
49 return IsSupportedForDataTypeRef(reasonIfUnsupported,
50 input.GetDataType(),
51 &TrueFunc<>,
52 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010053}
54
55bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
56 const TensorInfo& input1,
57 const TensorInfo& output,
58 Optional<std::string&> reasonIfUnsupported) const
59{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010060 ignore_unused(input1);
61 ignore_unused(output);
62 return IsSupportedForDataTypeRef(reasonIfUnsupported,
63 input0.GetDataType(),
64 &TrueFunc<>,
65 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010066}
67
68bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
69 const TensorInfo& output,
70 const TensorInfo& mean,
71 const TensorInfo& var,
72 const TensorInfo& beta,
73 const TensorInfo& gamma,
74 const BatchNormalizationDescriptor& descriptor,
75 Optional<std::string&> reasonIfUnsupported) const
76{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010077 ignore_unused(output);
78 ignore_unused(mean);
79 ignore_unused(var);
80 ignore_unused(beta);
81 ignore_unused(gamma);
82 ignore_unused(descriptor);
83 return IsSupportedForDataTypeRef(reasonIfUnsupported,
84 input.GetDataType(),
85 &TrueFunc<>,
86 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010087}
88
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000089bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
90 const TensorInfo& output,
91 const BatchToSpaceNdDescriptor& descriptor,
92 Optional<std::string&> reasonIfUnsupported) const
93{
94 ignore_unused(descriptor);
95 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
96 input.GetDataType(),
97 &TrueFunc<>,
98 &TrueFunc<>) &&
99 IsSupportedForDataTypeRef(reasonIfUnsupported,
100 output.GetDataType(),
101 &TrueFunc<>,
102 &TrueFunc<>));
103}
104
arovir011c7c81b2018-10-08 11:34:28 +0100105bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
106 Optional<std::string&> reasonIfUnsupported) const
107{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100108 return IsSupportedForDataTypeRef(reasonIfUnsupported,
109 output.GetDataType(),
110 &TrueFunc<>,
111 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100112}
113
114bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
115 const TensorInfo& output,
116 Optional<std::string&> reasonIfUnsupported) const
117{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100118 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
119 input.GetDataType(),
120 &TrueFunc<>,
121 &FalseInputFuncF32<>,
122 &FalseFuncU8<>) &&
123 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
124 output.GetDataType(),
125 &FalseOutputFuncF16<>,
126 &TrueFunc<>,
127 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100128}
129
130bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
131 const TensorInfo& output,
132 Optional<std::string&> reasonIfUnsupported) const
133{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100134 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
135 input.GetDataType(),
136 &FalseInputFuncF16<>,
137 &TrueFunc<>,
138 &FalseFuncU8<>) &&
139 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
140 output.GetDataType(),
141 &TrueFunc<>,
142 &FalseOutputFuncF32<>,
143 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100144}
145
146bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
147 const TensorInfo& output,
148 const Convolution2dDescriptor& descriptor,
149 const TensorInfo& weights,
150 const Optional<TensorInfo>& biases,
151 Optional<std::string&> reasonIfUnsupported) const
152{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100153 ignore_unused(output);
154 ignore_unused(descriptor);
155 ignore_unused(weights);
156 ignore_unused(biases);
157 return IsSupportedForDataTypeRef(reasonIfUnsupported,
158 input.GetDataType(),
159 &TrueFunc<>,
160 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100161}
162
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000163bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
164 const TensorInfo& output,
165 const DebugDescriptor& descriptor,
166 Optional<std::string&> reasonIfUnsupported) const
167{
168 ignore_unused(output);
169 ignore_unused(descriptor);
170 return IsSupportedForDataTypeRef(reasonIfUnsupported,
171 input.GetDataType(),
172 &TrueFunc<>,
173 &TrueFunc<>);
174}
175
arovir011c7c81b2018-10-08 11:34:28 +0100176bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
177 const TensorInfo& output,
178 const DepthwiseConvolution2dDescriptor& descriptor,
179 const TensorInfo& weights,
180 const Optional<TensorInfo>& biases,
181 Optional<std::string&> reasonIfUnsupported) const
182{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100183 ignore_unused(output);
184 ignore_unused(descriptor);
185 ignore_unused(weights);
186 ignore_unused(biases);
187 return IsSupportedForDataTypeRef(reasonIfUnsupported,
188 input.GetDataType(),
189 &TrueFunc<>,
190 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100191}
192
193bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
194 const TensorInfo& input1,
195 const TensorInfo& output,
196 Optional<std::string&> reasonIfUnsupported) const
197{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100198 ignore_unused(input1);
199 ignore_unused(output);
200 return IsSupportedForDataTypeRef(reasonIfUnsupported,
201 input0.GetDataType(),
202 &TrueFunc<>,
203 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100204}
205
206bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
207 const FakeQuantizationDescriptor& descriptor,
208 Optional<std::string&> reasonIfUnsupported) const
209{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100210 ignore_unused(descriptor);
211 return IsSupportedForDataTypeRef(reasonIfUnsupported,
212 input.GetDataType(),
213 &TrueFunc<>,
214 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100215}
216
217bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
218 const TensorInfo& output,
219 Optional<std::string&> reasonIfUnsupported) const
220{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100221 ignore_unused(output);
222 return IsSupportedForDataTypeRef(reasonIfUnsupported,
223 input.GetDataType(),
224 &TrueFunc<>,
225 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100226}
227
228bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
229 const TensorInfo& output,
230 const TensorInfo& weights,
231 const TensorInfo& biases,
232 const FullyConnectedDescriptor& descriptor,
233 Optional<std::string&> reasonIfUnsupported) const
234{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100235 ignore_unused(output);
236 ignore_unused(weights);
237 ignore_unused(biases);
238 ignore_unused(descriptor);
239 return IsSupportedForDataTypeRef(reasonIfUnsupported,
240 input.GetDataType(),
241 &TrueFunc<>,
242 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100243}
244
245bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
246 Optional<std::string&> reasonIfUnsupported) const
247{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100248 return IsSupportedForDataTypeRef(reasonIfUnsupported,
249 input.GetDataType(),
250 &TrueFunc<>,
251 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100252}
253
254bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
255 const TensorInfo& output,
256 const L2NormalizationDescriptor& descriptor,
257 Optional<std::string&> reasonIfUnsupported) const
258{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100259 ignore_unused(output);
260 ignore_unused(descriptor);
261 return IsSupportedForDataTypeRef(reasonIfUnsupported,
262 input.GetDataType(),
263 &TrueFunc<>,
264 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100265}
266
267bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
268 const TensorInfo& outputStateIn,
269 const TensorInfo& cellStateIn,
270 const TensorInfo& scratchBuffer,
271 const TensorInfo& outputStateOut,
272 const TensorInfo& cellStateOut,
273 const TensorInfo& output,
274 const LstmDescriptor& descriptor,
275 const TensorInfo& inputToForgetWeights,
276 const TensorInfo& inputToCellWeights,
277 const TensorInfo& inputToOutputWeights,
278 const TensorInfo& recurrentToForgetWeights,
279 const TensorInfo& recurrentToCellWeights,
280 const TensorInfo& recurrentToOutputWeights,
281 const TensorInfo& forgetGateBias,
282 const TensorInfo& cellBias,
283 const TensorInfo& outputGateBias,
284 const TensorInfo* inputToInputWeights,
285 const TensorInfo* recurrentToInputWeights,
286 const TensorInfo* cellToInputWeights,
287 const TensorInfo* inputGateBias,
288 const TensorInfo* projectionWeights,
289 const TensorInfo* projectionBias,
290 const TensorInfo* cellToForgetWeights,
291 const TensorInfo* cellToOutputWeights,
292 Optional<std::string&> reasonIfUnsupported) const
293{
telsoa01c577f2c2018-08-31 09:22:23 +0100294 ignore_unused(outputStateIn);
295 ignore_unused(cellStateIn);
296 ignore_unused(scratchBuffer);
297 ignore_unused(outputStateOut);
298 ignore_unused(cellStateOut);
299 ignore_unused(output);
300 ignore_unused(descriptor);
301 ignore_unused(inputToForgetWeights);
302 ignore_unused(inputToCellWeights);
303 ignore_unused(inputToOutputWeights);
304 ignore_unused(recurrentToForgetWeights);
305 ignore_unused(recurrentToCellWeights);
306 ignore_unused(recurrentToOutputWeights);
307 ignore_unused(forgetGateBias);
308 ignore_unused(cellBias);
309 ignore_unused(outputGateBias);
310 ignore_unused(inputToInputWeights);
311 ignore_unused(recurrentToInputWeights);
312 ignore_unused(cellToInputWeights);
313 ignore_unused(inputGateBias);
314 ignore_unused(projectionWeights);
315 ignore_unused(projectionBias);
316 ignore_unused(cellToForgetWeights);
317 ignore_unused(cellToOutputWeights);
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000318 return IsSupportedForDataTypeRef(reasonIfUnsupported,
319 input.GetDataType(),
320 &TrueFunc<>,
321 &FalseFuncU8<>);
telsoa01c577f2c2018-08-31 09:22:23 +0100322}
323
saoste012df12b32018-11-28 16:57:20 +0000324bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
325 const TensorInfo& input1,
326 const TensorInfo& output,
327 Optional<std::string&> reasonIfUnsupported) const
328{
329 ignore_unused(input1);
330 ignore_unused(output);
331 return IsSupportedForDataTypeRef(reasonIfUnsupported,
332 input0.GetDataType(),
333 &TrueFunc<>,
334 &TrueFunc<>);
335}
336
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100337bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
338 const TensorInfo& output,
339 const MeanDescriptor& descriptor,
340 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100341{
narpra011e4c31d2018-09-28 11:07:51 +0100342 ignore_unused(output);
343 ignore_unused(descriptor);
344 return IsSupportedForDataTypeRef(reasonIfUnsupported,
345 input.GetDataType(),
346 &TrueFunc<>,
347 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100348}
349
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100350bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000351 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100352 const OriginsDescriptor& descriptor,
353 Optional<std::string&> reasonIfUnsupported) const
354{
355 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000356 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100357 return IsSupportedForDataTypeRef(reasonIfUnsupported,
358 inputs[0]->GetDataType(),
359 &TrueFunc<>,
360 &TrueFunc<>);
361}
362
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000363bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
364 const TensorInfo& input1,
365 const TensorInfo& output,
366 Optional<std::string&> reasonIfUnsupported) const
367{
368 ignore_unused(input1);
369 ignore_unused(output);
370 return IsSupportedForDataTypeRef(reasonIfUnsupported,
371 input0.GetDataType(),
372 &TrueFunc<>,
373 &TrueFunc<>);
374}
375
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100376bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
377 const TensorInfo& input1,
378 const TensorInfo& output,
379 Optional<std::string&> reasonIfUnsupported) const
380{
381 ignore_unused(input1);
382 ignore_unused(output);
383 return IsSupportedForDataTypeRef(reasonIfUnsupported,
384 input0.GetDataType(),
385 &TrueFunc<>,
386 &TrueFunc<>);
387}
388
389bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
390 const TensorInfo& output,
391 const NormalizationDescriptor& descriptor,
392 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100393{
394 ignore_unused(output);
395 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100396 return IsSupportedForDataTypeRef(reasonIfUnsupported,
397 input.GetDataType(),
398 &TrueFunc<>,
399 &FalseFuncU8<>);
400}
401
402bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
403 Optional<std::string&> reasonIfUnsupported) const
404{
405 return IsSupportedForDataTypeRef(reasonIfUnsupported,
406 output.GetDataType(),
407 &TrueFunc<>,
408 &TrueFunc<>);
409}
410
411bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
412 const TensorInfo& output,
413 const PadDescriptor& descriptor,
414 Optional<std::string&> reasonIfUnsupported) const
415{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100416 ignore_unused(output);
417 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000418 return IsSupportedForDataTypeRef(reasonIfUnsupported,
419 input.GetDataType(),
420 &TrueFunc<>,
421 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100422}
423
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100424bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
425 const TensorInfo& output,
426 const PermuteDescriptor& descriptor,
427 Optional<std::string&> reasonIfUnsupported) const
428{
429 ignore_unused(output);
430 ignore_unused(descriptor);
431 return IsSupportedForDataTypeRef(reasonIfUnsupported,
432 input.GetDataType(),
433 &TrueFunc<>,
434 &TrueFunc<>);
435}
436
437bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
438 const TensorInfo& output,
439 const Pooling2dDescriptor& descriptor,
440 Optional<std::string&> reasonIfUnsupported) const
441{
442 ignore_unused(output);
443 ignore_unused(descriptor);
444 return IsSupportedForDataTypeRef(reasonIfUnsupported,
445 input.GetDataType(),
446 &TrueFunc<>,
447 &TrueFunc<>);
448}
449
450bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
451 Optional<std::string&> reasonIfUnsupported) const
452{
453 return IsSupportedForDataTypeRef(reasonIfUnsupported,
454 input.GetDataType(),
455 &TrueFunc<>,
456 &TrueFunc<>);
457}
458
459bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
460 Optional<std::string&> reasonIfUnsupported) const
461{
462 return IsSupportedForDataTypeRef(reasonIfUnsupported,
463 input.GetDataType(),
464 &TrueFunc<>,
465 &TrueFunc<>);
466}
467
468bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
469 const TensorInfo& output,
470 const SoftmaxDescriptor& descriptor,
471 Optional<std::string&> reasonIfUnsupported) const
472{
473 ignore_unused(output);
474 ignore_unused(descriptor);
475 return IsSupportedForDataTypeRef(reasonIfUnsupported,
476 input.GetDataType(),
477 &TrueFunc<>,
478 &TrueFunc<>);
479}
480
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000481bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
482 const TensorInfo& output,
483 const SpaceToBatchNdDescriptor& descriptor,
484 Optional<std::string&> reasonIfUnsupported) const
485{
486 ignore_unused(output);
487 ignore_unused(descriptor);
488 return IsSupportedForDataTypeRef(reasonIfUnsupported,
489 input.GetDataType(),
490 &TrueFunc<>,
491 &TrueFunc<>);
492}
493
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100494bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
495 const ViewsDescriptor& descriptor,
496 Optional<std::string&> reasonIfUnsupported) const
497{
498 ignore_unused(descriptor);
499 return IsSupportedForDataTypeRef(reasonIfUnsupported,
500 input.GetDataType(),
501 &TrueFunc<>,
502 &TrueFunc<>);
503}
504
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000505bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
506 const TensorInfo& output,
507 const StridedSliceDescriptor& descriptor,
508 Optional<std::string&> reasonIfUnsupported) const
509{
510 ignore_unused(output);
511 ignore_unused(descriptor);
512 return IsSupportedForDataTypeRef(reasonIfUnsupported,
513 input.GetDataType(),
514 &TrueFunc<>,
515 &TrueFunc<>);
516}
517
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100518bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
519 const TensorInfo& input1,
520 const TensorInfo& output,
521 Optional<std::string&> reasonIfUnsupported) const
522{
523 ignore_unused(input1);
524 ignore_unused(output);
525 return IsSupportedForDataTypeRef(reasonIfUnsupported,
526 input0.GetDataType(),
527 &TrueFunc<>,
528 &TrueFunc<>);
529}
530
arovir011c7c81b2018-10-08 11:34:28 +0100531} // namespace armnn