blob: b057370459eafbb23aead9d6049d988f1aab3e7d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <backendsCommon/LayerSupportRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010014
telsoa014fcda012018-03-09 14:13:49 +000015#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
17using namespace boost;
18
19namespace armnn
20{
21
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010022namespace
23{
24
David Beck3e9e1152018-10-17 14:17:50 +010025ILayerSupportSharedPtr GetLayerSupportPointer()
26{
27 static ILayerSupportSharedPtr instance{new RefLayerSupport};
28 return instance;
29}
30
31static StaticRegistryInitializer<LayerSupportRegistry> g_RegisterHelper{
32 LayerSupportRegistryInstance(),
33 RefBackendId(),
David Beck9efb57d2018-11-05 13:40:33 +000034 []()
David Beck3e9e1152018-10-17 14:17:50 +010035 {
36 return GetLayerSupportPointer();
37 }
38};
39
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040template<typename Float32Func, typename Uint8Func, typename ... Params>
41bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
42 DataType dataType,
43 Float32Func floatFuncPtr,
44 Uint8Func uint8FuncPtr,
45 Params&&... params)
46{
47 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
48 dataType,
49 &FalseFunc<Params...>,
50 floatFuncPtr,
51 uint8FuncPtr,
52 std::forward<Params>(params)...);
53}
54
55} // anonymous namespace
56
arovir011c7c81b2018-10-08 11:34:28 +010057bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
58 const TensorInfo& output,
59 const ActivationDescriptor& descriptor,
60 Optional<std::string&> reasonIfUnsupported) const
61{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010062 ignore_unused(output);
63 ignore_unused(descriptor);
64 return IsSupportedForDataTypeRef(reasonIfUnsupported,
65 input.GetDataType(),
66 &TrueFunc<>,
67 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010068}
69
70bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
71 const TensorInfo& input1,
72 const TensorInfo& output,
73 Optional<std::string&> reasonIfUnsupported) const
74{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010075 ignore_unused(input1);
76 ignore_unused(output);
77 return IsSupportedForDataTypeRef(reasonIfUnsupported,
78 input0.GetDataType(),
79 &TrueFunc<>,
80 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010081}
82
83bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
84 const TensorInfo& output,
85 const TensorInfo& mean,
86 const TensorInfo& var,
87 const TensorInfo& beta,
88 const TensorInfo& gamma,
89 const BatchNormalizationDescriptor& descriptor,
90 Optional<std::string&> reasonIfUnsupported) const
91{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010092 ignore_unused(output);
93 ignore_unused(mean);
94 ignore_unused(var);
95 ignore_unused(beta);
96 ignore_unused(gamma);
97 ignore_unused(descriptor);
98 return IsSupportedForDataTypeRef(reasonIfUnsupported,
99 input.GetDataType(),
100 &TrueFunc<>,
101 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100102}
103
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000104bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
105 const TensorInfo& output,
106 const BatchToSpaceNdDescriptor& descriptor,
107 Optional<std::string&> reasonIfUnsupported) const
108{
109 ignore_unused(descriptor);
110 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
111 input.GetDataType(),
112 &TrueFunc<>,
113 &TrueFunc<>) &&
114 IsSupportedForDataTypeRef(reasonIfUnsupported,
115 output.GetDataType(),
116 &TrueFunc<>,
117 &TrueFunc<>));
118}
119
arovir011c7c81b2018-10-08 11:34:28 +0100120bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
121 Optional<std::string&> reasonIfUnsupported) const
122{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100123 return IsSupportedForDataTypeRef(reasonIfUnsupported,
124 output.GetDataType(),
125 &TrueFunc<>,
126 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100127}
128
129bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
130 const TensorInfo& output,
131 Optional<std::string&> reasonIfUnsupported) const
132{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100133 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
134 input.GetDataType(),
135 &TrueFunc<>,
136 &FalseInputFuncF32<>,
137 &FalseFuncU8<>) &&
138 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
139 output.GetDataType(),
140 &FalseOutputFuncF16<>,
141 &TrueFunc<>,
142 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100143}
144
145bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
146 const TensorInfo& output,
147 Optional<std::string&> reasonIfUnsupported) const
148{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100149 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
150 input.GetDataType(),
151 &FalseInputFuncF16<>,
152 &TrueFunc<>,
153 &FalseFuncU8<>) &&
154 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
155 output.GetDataType(),
156 &TrueFunc<>,
157 &FalseOutputFuncF32<>,
158 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100159}
160
161bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
162 const TensorInfo& output,
163 const Convolution2dDescriptor& descriptor,
164 const TensorInfo& weights,
165 const Optional<TensorInfo>& biases,
166 Optional<std::string&> reasonIfUnsupported) const
167{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100168 ignore_unused(output);
169 ignore_unused(descriptor);
170 ignore_unused(weights);
171 ignore_unused(biases);
172 return IsSupportedForDataTypeRef(reasonIfUnsupported,
173 input.GetDataType(),
174 &TrueFunc<>,
175 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100176}
177
178bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
179 const TensorInfo& output,
180 const DepthwiseConvolution2dDescriptor& descriptor,
181 const TensorInfo& weights,
182 const Optional<TensorInfo>& biases,
183 Optional<std::string&> reasonIfUnsupported) const
184{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100185 ignore_unused(output);
186 ignore_unused(descriptor);
187 ignore_unused(weights);
188 ignore_unused(biases);
189 return IsSupportedForDataTypeRef(reasonIfUnsupported,
190 input.GetDataType(),
191 &TrueFunc<>,
192 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100193}
194
195bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
196 const TensorInfo& input1,
197 const TensorInfo& output,
198 Optional<std::string&> reasonIfUnsupported) const
199{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100200 ignore_unused(input1);
201 ignore_unused(output);
202 return IsSupportedForDataTypeRef(reasonIfUnsupported,
203 input0.GetDataType(),
204 &TrueFunc<>,
205 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100206}
207
208bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
209 const FakeQuantizationDescriptor& descriptor,
210 Optional<std::string&> reasonIfUnsupported) const
211{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100212 ignore_unused(descriptor);
213 return IsSupportedForDataTypeRef(reasonIfUnsupported,
214 input.GetDataType(),
215 &TrueFunc<>,
216 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100217}
218
219bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
220 const TensorInfo& output,
221 Optional<std::string&> reasonIfUnsupported) const
222{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100223 ignore_unused(output);
224 return IsSupportedForDataTypeRef(reasonIfUnsupported,
225 input.GetDataType(),
226 &TrueFunc<>,
227 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100228}
229
230bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
231 const TensorInfo& output,
232 const TensorInfo& weights,
233 const TensorInfo& biases,
234 const FullyConnectedDescriptor& descriptor,
235 Optional<std::string&> reasonIfUnsupported) const
236{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100237 ignore_unused(output);
238 ignore_unused(weights);
239 ignore_unused(biases);
240 ignore_unused(descriptor);
241 return IsSupportedForDataTypeRef(reasonIfUnsupported,
242 input.GetDataType(),
243 &TrueFunc<>,
244 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100245}
246
247bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
248 Optional<std::string&> reasonIfUnsupported) const
249{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100250 return IsSupportedForDataTypeRef(reasonIfUnsupported,
251 input.GetDataType(),
252 &TrueFunc<>,
253 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100254}
255
256bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
257 const TensorInfo& output,
258 const L2NormalizationDescriptor& descriptor,
259 Optional<std::string&> reasonIfUnsupported) const
260{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100261 ignore_unused(output);
262 ignore_unused(descriptor);
263 return IsSupportedForDataTypeRef(reasonIfUnsupported,
264 input.GetDataType(),
265 &TrueFunc<>,
266 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100267}
268
269bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
270 const TensorInfo& outputStateIn,
271 const TensorInfo& cellStateIn,
272 const TensorInfo& scratchBuffer,
273 const TensorInfo& outputStateOut,
274 const TensorInfo& cellStateOut,
275 const TensorInfo& output,
276 const LstmDescriptor& descriptor,
277 const TensorInfo& inputToForgetWeights,
278 const TensorInfo& inputToCellWeights,
279 const TensorInfo& inputToOutputWeights,
280 const TensorInfo& recurrentToForgetWeights,
281 const TensorInfo& recurrentToCellWeights,
282 const TensorInfo& recurrentToOutputWeights,
283 const TensorInfo& forgetGateBias,
284 const TensorInfo& cellBias,
285 const TensorInfo& outputGateBias,
286 const TensorInfo* inputToInputWeights,
287 const TensorInfo* recurrentToInputWeights,
288 const TensorInfo* cellToInputWeights,
289 const TensorInfo* inputGateBias,
290 const TensorInfo* projectionWeights,
291 const TensorInfo* projectionBias,
292 const TensorInfo* cellToForgetWeights,
293 const TensorInfo* cellToOutputWeights,
294 Optional<std::string&> reasonIfUnsupported) const
295{
telsoa01c577f2c2018-08-31 09:22:23 +0100296 ignore_unused(input);
297 ignore_unused(outputStateIn);
298 ignore_unused(cellStateIn);
299 ignore_unused(scratchBuffer);
300 ignore_unused(outputStateOut);
301 ignore_unused(cellStateOut);
302 ignore_unused(output);
303 ignore_unused(descriptor);
304 ignore_unused(inputToForgetWeights);
305 ignore_unused(inputToCellWeights);
306 ignore_unused(inputToOutputWeights);
307 ignore_unused(recurrentToForgetWeights);
308 ignore_unused(recurrentToCellWeights);
309 ignore_unused(recurrentToOutputWeights);
310 ignore_unused(forgetGateBias);
311 ignore_unused(cellBias);
312 ignore_unused(outputGateBias);
313 ignore_unused(inputToInputWeights);
314 ignore_unused(recurrentToInputWeights);
315 ignore_unused(cellToInputWeights);
316 ignore_unused(inputGateBias);
317 ignore_unused(projectionWeights);
318 ignore_unused(projectionBias);
319 ignore_unused(cellToForgetWeights);
320 ignore_unused(cellToOutputWeights);
arovir01085f0a42018-10-08 14:48:19 +0100321 ignore_unused(reasonIfUnsupported);
telsoa01c577f2c2018-08-31 09:22:23 +0100322 return false;
323}
324
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100325bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
326 const TensorInfo& output,
327 const MeanDescriptor& descriptor,
328 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100329{
narpra011e4c31d2018-09-28 11:07:51 +0100330 ignore_unused(output);
331 ignore_unused(descriptor);
332 return IsSupportedForDataTypeRef(reasonIfUnsupported,
333 input.GetDataType(),
334 &TrueFunc<>,
335 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100336}
337
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100338bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
339 const OriginsDescriptor& descriptor,
340 Optional<std::string&> reasonIfUnsupported) const
341{
342 ignore_unused(descriptor);
343 return IsSupportedForDataTypeRef(reasonIfUnsupported,
344 inputs[0]->GetDataType(),
345 &TrueFunc<>,
346 &TrueFunc<>);
347}
348
349bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
350 const TensorInfo& input1,
351 const TensorInfo& output,
352 Optional<std::string&> reasonIfUnsupported) const
353{
354 ignore_unused(input1);
355 ignore_unused(output);
356 return IsSupportedForDataTypeRef(reasonIfUnsupported,
357 input0.GetDataType(),
358 &TrueFunc<>,
359 &TrueFunc<>);
360}
361
362bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
363 const TensorInfo& output,
364 const NormalizationDescriptor& descriptor,
365 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100366{
367 ignore_unused(output);
368 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100369 return IsSupportedForDataTypeRef(reasonIfUnsupported,
370 input.GetDataType(),
371 &TrueFunc<>,
372 &FalseFuncU8<>);
373}
374
375bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
376 Optional<std::string&> reasonIfUnsupported) const
377{
378 return IsSupportedForDataTypeRef(reasonIfUnsupported,
379 output.GetDataType(),
380 &TrueFunc<>,
381 &TrueFunc<>);
382}
383
384bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
385 const TensorInfo& output,
386 const PadDescriptor& descriptor,
387 Optional<std::string&> reasonIfUnsupported) const
388{
389 ignore_unused(input);
390 ignore_unused(output);
391 ignore_unused(descriptor);
392 ignore_unused(reasonIfUnsupported);
Nina Drozd661dfa72018-10-02 11:14:17 +0100393 return false;
394}
395
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100396bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
397 const TensorInfo& output,
398 const PermuteDescriptor& descriptor,
399 Optional<std::string&> reasonIfUnsupported) const
400{
401 ignore_unused(output);
402 ignore_unused(descriptor);
403 return IsSupportedForDataTypeRef(reasonIfUnsupported,
404 input.GetDataType(),
405 &TrueFunc<>,
406 &TrueFunc<>);
407}
408
409bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
410 const TensorInfo& output,
411 const Pooling2dDescriptor& descriptor,
412 Optional<std::string&> reasonIfUnsupported) const
413{
414 ignore_unused(output);
415 ignore_unused(descriptor);
416 return IsSupportedForDataTypeRef(reasonIfUnsupported,
417 input.GetDataType(),
418 &TrueFunc<>,
419 &TrueFunc<>);
420}
421
422bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
423 Optional<std::string&> reasonIfUnsupported) const
424{
425 return IsSupportedForDataTypeRef(reasonIfUnsupported,
426 input.GetDataType(),
427 &TrueFunc<>,
428 &TrueFunc<>);
429}
430
431bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
432 Optional<std::string&> reasonIfUnsupported) const
433{
434 return IsSupportedForDataTypeRef(reasonIfUnsupported,
435 input.GetDataType(),
436 &TrueFunc<>,
437 &TrueFunc<>);
438}
439
440bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
441 const TensorInfo& output,
442 const SoftmaxDescriptor& descriptor,
443 Optional<std::string&> reasonIfUnsupported) const
444{
445 ignore_unused(output);
446 ignore_unused(descriptor);
447 return IsSupportedForDataTypeRef(reasonIfUnsupported,
448 input.GetDataType(),
449 &TrueFunc<>,
450 &TrueFunc<>);
451}
452
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000453bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
454 const TensorInfo& output,
455 const SpaceToBatchNdDescriptor& descriptor,
456 Optional<std::string&> reasonIfUnsupported) const
457{
458 ignore_unused(output);
459 ignore_unused(descriptor);
460 return IsSupportedForDataTypeRef(reasonIfUnsupported,
461 input.GetDataType(),
462 &TrueFunc<>,
463 &TrueFunc<>);
464}
465
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100466bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
467 const ViewsDescriptor& descriptor,
468 Optional<std::string&> reasonIfUnsupported) const
469{
470 ignore_unused(descriptor);
471 return IsSupportedForDataTypeRef(reasonIfUnsupported,
472 input.GetDataType(),
473 &TrueFunc<>,
474 &TrueFunc<>);
475}
476
477bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
478 const TensorInfo& input1,
479 const TensorInfo& output,
480 Optional<std::string&> reasonIfUnsupported) const
481{
482 ignore_unused(input1);
483 ignore_unused(output);
484 return IsSupportedForDataTypeRef(reasonIfUnsupported,
485 input0.GetDataType(),
486 &TrueFunc<>,
487 &TrueFunc<>);
488}
489
arovir011c7c81b2018-10-08 11:34:28 +0100490} // namespace armnn