blob: 00e4c5c09ce715c4f7ae84465cff5882fec06682 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
David Beck111b5d92018-11-12 14:59:37 +000013#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010014
telsoa014fcda012018-03-09 14:13:49 +000015#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
17using namespace boost;
18
19namespace armnn
20{
21
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010022namespace
23{
24
25template<typename Float32Func, typename Uint8Func, typename ... Params>
26bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
27 DataType dataType,
28 Float32Func floatFuncPtr,
29 Uint8Func uint8FuncPtr,
30 Params&&... params)
31{
32 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
33 dataType,
34 &FalseFunc<Params...>,
35 floatFuncPtr,
36 uint8FuncPtr,
37 std::forward<Params>(params)...);
38}
39
40} // anonymous namespace
41
arovir011c7c81b2018-10-08 11:34:28 +010042bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
43 const TensorInfo& output,
44 const ActivationDescriptor& descriptor,
45 Optional<std::string&> reasonIfUnsupported) const
46{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010047 ignore_unused(output);
48 ignore_unused(descriptor);
49 return IsSupportedForDataTypeRef(reasonIfUnsupported,
50 input.GetDataType(),
51 &TrueFunc<>,
52 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010053}
54
55bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
56 const TensorInfo& input1,
57 const TensorInfo& output,
58 Optional<std::string&> reasonIfUnsupported) const
59{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010060 ignore_unused(input1);
61 ignore_unused(output);
62 return IsSupportedForDataTypeRef(reasonIfUnsupported,
63 input0.GetDataType(),
64 &TrueFunc<>,
65 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010066}
67
68bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
69 const TensorInfo& output,
70 const TensorInfo& mean,
71 const TensorInfo& var,
72 const TensorInfo& beta,
73 const TensorInfo& gamma,
74 const BatchNormalizationDescriptor& descriptor,
75 Optional<std::string&> reasonIfUnsupported) const
76{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010077 ignore_unused(output);
78 ignore_unused(mean);
79 ignore_unused(var);
80 ignore_unused(beta);
81 ignore_unused(gamma);
82 ignore_unused(descriptor);
83 return IsSupportedForDataTypeRef(reasonIfUnsupported,
84 input.GetDataType(),
85 &TrueFunc<>,
86 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010087}
88
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000089bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
90 const TensorInfo& output,
91 const BatchToSpaceNdDescriptor& descriptor,
92 Optional<std::string&> reasonIfUnsupported) const
93{
94 ignore_unused(descriptor);
95 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
96 input.GetDataType(),
97 &TrueFunc<>,
98 &TrueFunc<>) &&
99 IsSupportedForDataTypeRef(reasonIfUnsupported,
100 output.GetDataType(),
101 &TrueFunc<>,
102 &TrueFunc<>));
103}
104
arovir011c7c81b2018-10-08 11:34:28 +0100105bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
106 Optional<std::string&> reasonIfUnsupported) const
107{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100108 return IsSupportedForDataTypeRef(reasonIfUnsupported,
109 output.GetDataType(),
110 &TrueFunc<>,
111 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100112}
113
114bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
115 const TensorInfo& output,
116 Optional<std::string&> reasonIfUnsupported) const
117{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100118 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
119 input.GetDataType(),
120 &TrueFunc<>,
121 &FalseInputFuncF32<>,
122 &FalseFuncU8<>) &&
123 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
124 output.GetDataType(),
125 &FalseOutputFuncF16<>,
126 &TrueFunc<>,
127 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100128}
129
130bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
131 const TensorInfo& output,
132 Optional<std::string&> reasonIfUnsupported) const
133{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100134 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
135 input.GetDataType(),
136 &FalseInputFuncF16<>,
137 &TrueFunc<>,
138 &FalseFuncU8<>) &&
139 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
140 output.GetDataType(),
141 &TrueFunc<>,
142 &FalseOutputFuncF32<>,
143 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100144}
145
146bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
147 const TensorInfo& output,
148 const Convolution2dDescriptor& descriptor,
149 const TensorInfo& weights,
150 const Optional<TensorInfo>& biases,
151 Optional<std::string&> reasonIfUnsupported) const
152{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100153 ignore_unused(output);
154 ignore_unused(descriptor);
155 ignore_unused(weights);
156 ignore_unused(biases);
157 return IsSupportedForDataTypeRef(reasonIfUnsupported,
158 input.GetDataType(),
159 &TrueFunc<>,
160 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100161}
162
163bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
164 const TensorInfo& output,
165 const DepthwiseConvolution2dDescriptor& descriptor,
166 const TensorInfo& weights,
167 const Optional<TensorInfo>& biases,
168 Optional<std::string&> reasonIfUnsupported) const
169{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100170 ignore_unused(output);
171 ignore_unused(descriptor);
172 ignore_unused(weights);
173 ignore_unused(biases);
174 return IsSupportedForDataTypeRef(reasonIfUnsupported,
175 input.GetDataType(),
176 &TrueFunc<>,
177 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100178}
179
180bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
181 const TensorInfo& input1,
182 const TensorInfo& output,
183 Optional<std::string&> reasonIfUnsupported) const
184{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100185 ignore_unused(input1);
186 ignore_unused(output);
187 return IsSupportedForDataTypeRef(reasonIfUnsupported,
188 input0.GetDataType(),
189 &TrueFunc<>,
190 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100191}
192
193bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
194 const FakeQuantizationDescriptor& descriptor,
195 Optional<std::string&> reasonIfUnsupported) const
196{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100197 ignore_unused(descriptor);
198 return IsSupportedForDataTypeRef(reasonIfUnsupported,
199 input.GetDataType(),
200 &TrueFunc<>,
201 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100202}
203
204bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
205 const TensorInfo& output,
206 Optional<std::string&> reasonIfUnsupported) const
207{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100208 ignore_unused(output);
209 return IsSupportedForDataTypeRef(reasonIfUnsupported,
210 input.GetDataType(),
211 &TrueFunc<>,
212 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100213}
214
215bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
216 const TensorInfo& output,
217 const TensorInfo& weights,
218 const TensorInfo& biases,
219 const FullyConnectedDescriptor& descriptor,
220 Optional<std::string&> reasonIfUnsupported) const
221{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100222 ignore_unused(output);
223 ignore_unused(weights);
224 ignore_unused(biases);
225 ignore_unused(descriptor);
226 return IsSupportedForDataTypeRef(reasonIfUnsupported,
227 input.GetDataType(),
228 &TrueFunc<>,
229 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100230}
231
232bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
233 Optional<std::string&> reasonIfUnsupported) const
234{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100235 return IsSupportedForDataTypeRef(reasonIfUnsupported,
236 input.GetDataType(),
237 &TrueFunc<>,
238 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100239}
240
241bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
242 const TensorInfo& output,
243 const L2NormalizationDescriptor& descriptor,
244 Optional<std::string&> reasonIfUnsupported) const
245{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100246 ignore_unused(output);
247 ignore_unused(descriptor);
248 return IsSupportedForDataTypeRef(reasonIfUnsupported,
249 input.GetDataType(),
250 &TrueFunc<>,
251 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100252}
253
254bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
255 const TensorInfo& outputStateIn,
256 const TensorInfo& cellStateIn,
257 const TensorInfo& scratchBuffer,
258 const TensorInfo& outputStateOut,
259 const TensorInfo& cellStateOut,
260 const TensorInfo& output,
261 const LstmDescriptor& descriptor,
262 const TensorInfo& inputToForgetWeights,
263 const TensorInfo& inputToCellWeights,
264 const TensorInfo& inputToOutputWeights,
265 const TensorInfo& recurrentToForgetWeights,
266 const TensorInfo& recurrentToCellWeights,
267 const TensorInfo& recurrentToOutputWeights,
268 const TensorInfo& forgetGateBias,
269 const TensorInfo& cellBias,
270 const TensorInfo& outputGateBias,
271 const TensorInfo* inputToInputWeights,
272 const TensorInfo* recurrentToInputWeights,
273 const TensorInfo* cellToInputWeights,
274 const TensorInfo* inputGateBias,
275 const TensorInfo* projectionWeights,
276 const TensorInfo* projectionBias,
277 const TensorInfo* cellToForgetWeights,
278 const TensorInfo* cellToOutputWeights,
279 Optional<std::string&> reasonIfUnsupported) const
280{
telsoa01c577f2c2018-08-31 09:22:23 +0100281 ignore_unused(outputStateIn);
282 ignore_unused(cellStateIn);
283 ignore_unused(scratchBuffer);
284 ignore_unused(outputStateOut);
285 ignore_unused(cellStateOut);
286 ignore_unused(output);
287 ignore_unused(descriptor);
288 ignore_unused(inputToForgetWeights);
289 ignore_unused(inputToCellWeights);
290 ignore_unused(inputToOutputWeights);
291 ignore_unused(recurrentToForgetWeights);
292 ignore_unused(recurrentToCellWeights);
293 ignore_unused(recurrentToOutputWeights);
294 ignore_unused(forgetGateBias);
295 ignore_unused(cellBias);
296 ignore_unused(outputGateBias);
297 ignore_unused(inputToInputWeights);
298 ignore_unused(recurrentToInputWeights);
299 ignore_unused(cellToInputWeights);
300 ignore_unused(inputGateBias);
301 ignore_unused(projectionWeights);
302 ignore_unused(projectionBias);
303 ignore_unused(cellToForgetWeights);
304 ignore_unused(cellToOutputWeights);
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000305 return IsSupportedForDataTypeRef(reasonIfUnsupported,
306 input.GetDataType(),
307 &TrueFunc<>,
308 &FalseFuncU8<>);
telsoa01c577f2c2018-08-31 09:22:23 +0100309}
310
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100311bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
312 const TensorInfo& output,
313 const MeanDescriptor& descriptor,
314 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100315{
narpra011e4c31d2018-09-28 11:07:51 +0100316 ignore_unused(output);
317 ignore_unused(descriptor);
318 return IsSupportedForDataTypeRef(reasonIfUnsupported,
319 input.GetDataType(),
320 &TrueFunc<>,
321 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100322}
323
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100324bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000325 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100326 const OriginsDescriptor& descriptor,
327 Optional<std::string&> reasonIfUnsupported) const
328{
329 ignore_unused(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +0000330 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100331 return IsSupportedForDataTypeRef(reasonIfUnsupported,
332 inputs[0]->GetDataType(),
333 &TrueFunc<>,
334 &TrueFunc<>);
335}
336
337bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
338 const TensorInfo& input1,
339 const TensorInfo& output,
340 Optional<std::string&> reasonIfUnsupported) const
341{
342 ignore_unused(input1);
343 ignore_unused(output);
344 return IsSupportedForDataTypeRef(reasonIfUnsupported,
345 input0.GetDataType(),
346 &TrueFunc<>,
347 &TrueFunc<>);
348}
349
350bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
351 const TensorInfo& output,
352 const NormalizationDescriptor& descriptor,
353 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100354{
355 ignore_unused(output);
356 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100357 return IsSupportedForDataTypeRef(reasonIfUnsupported,
358 input.GetDataType(),
359 &TrueFunc<>,
360 &FalseFuncU8<>);
361}
362
363bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
364 Optional<std::string&> reasonIfUnsupported) const
365{
366 return IsSupportedForDataTypeRef(reasonIfUnsupported,
367 output.GetDataType(),
368 &TrueFunc<>,
369 &TrueFunc<>);
370}
371
372bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
373 const TensorInfo& output,
374 const PadDescriptor& descriptor,
375 Optional<std::string&> reasonIfUnsupported) const
376{
377 ignore_unused(input);
378 ignore_unused(output);
379 ignore_unused(descriptor);
380 ignore_unused(reasonIfUnsupported);
Nina Drozd661dfa72018-10-02 11:14:17 +0100381 return false;
382}
383
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100384bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
385 const TensorInfo& output,
386 const PermuteDescriptor& descriptor,
387 Optional<std::string&> reasonIfUnsupported) const
388{
389 ignore_unused(output);
390 ignore_unused(descriptor);
391 return IsSupportedForDataTypeRef(reasonIfUnsupported,
392 input.GetDataType(),
393 &TrueFunc<>,
394 &TrueFunc<>);
395}
396
397bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
398 const TensorInfo& output,
399 const Pooling2dDescriptor& descriptor,
400 Optional<std::string&> reasonIfUnsupported) const
401{
402 ignore_unused(output);
403 ignore_unused(descriptor);
404 return IsSupportedForDataTypeRef(reasonIfUnsupported,
405 input.GetDataType(),
406 &TrueFunc<>,
407 &TrueFunc<>);
408}
409
410bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
411 Optional<std::string&> reasonIfUnsupported) const
412{
413 return IsSupportedForDataTypeRef(reasonIfUnsupported,
414 input.GetDataType(),
415 &TrueFunc<>,
416 &TrueFunc<>);
417}
418
419bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
420 Optional<std::string&> reasonIfUnsupported) const
421{
422 return IsSupportedForDataTypeRef(reasonIfUnsupported,
423 input.GetDataType(),
424 &TrueFunc<>,
425 &TrueFunc<>);
426}
427
428bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
429 const TensorInfo& output,
430 const SoftmaxDescriptor& descriptor,
431 Optional<std::string&> reasonIfUnsupported) const
432{
433 ignore_unused(output);
434 ignore_unused(descriptor);
435 return IsSupportedForDataTypeRef(reasonIfUnsupported,
436 input.GetDataType(),
437 &TrueFunc<>,
438 &TrueFunc<>);
439}
440
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000441bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
442 const TensorInfo& output,
443 const SpaceToBatchNdDescriptor& descriptor,
444 Optional<std::string&> reasonIfUnsupported) const
445{
446 ignore_unused(output);
447 ignore_unused(descriptor);
448 return IsSupportedForDataTypeRef(reasonIfUnsupported,
449 input.GetDataType(),
450 &TrueFunc<>,
451 &TrueFunc<>);
452}
453
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100454bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
455 const ViewsDescriptor& descriptor,
456 Optional<std::string&> reasonIfUnsupported) const
457{
458 ignore_unused(descriptor);
459 return IsSupportedForDataTypeRef(reasonIfUnsupported,
460 input.GetDataType(),
461 &TrueFunc<>,
462 &TrueFunc<>);
463}
464
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000465bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
466 const TensorInfo& output,
467 const StridedSliceDescriptor& descriptor,
468 Optional<std::string&> reasonIfUnsupported) const
469{
470 ignore_unused(output);
471 ignore_unused(descriptor);
472 return IsSupportedForDataTypeRef(reasonIfUnsupported,
473 input.GetDataType(),
474 &TrueFunc<>,
475 &TrueFunc<>);
476}
477
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100478bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
479 const TensorInfo& input1,
480 const TensorInfo& output,
481 Optional<std::string&> reasonIfUnsupported) const
482{
483 ignore_unused(input1);
484 ignore_unused(output);
485 return IsSupportedForDataTypeRef(reasonIfUnsupported,
486 input0.GetDataType(),
487 &TrueFunc<>,
488 &TrueFunc<>);
489}
490
arovir011c7c81b2018-10-08 11:34:28 +0100491} // namespace armnn