blob: 0902b0fd17e41877c4a5e35419233023a1a6ef0d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>
#include <armnn/Types.hpp>

#include <backendsCommon/LayerSupportRegistry.hpp>

#include <boost/core/ignore_unused.hpp>

#include <memory>
telsoa014fcda012018-03-09 14:13:49 +000016
17using namespace boost;
18
19namespace armnn
20{
21
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010022namespace
23{
24
David Beck3e9e1152018-10-17 14:17:50 +010025ILayerSupportSharedPtr GetLayerSupportPointer()
26{
27 static ILayerSupportSharedPtr instance{new RefLayerSupport};
28 return instance;
29}
30
31static StaticRegistryInitializer<LayerSupportRegistry> g_RegisterHelper{
32 LayerSupportRegistryInstance(),
33 RefBackendId(),
David Beck9efb57d2018-11-05 13:40:33 +000034 []()
David Beck3e9e1152018-10-17 14:17:50 +010035 {
36 return GetLayerSupportPointer();
37 }
38};
39
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040template<typename Float32Func, typename Uint8Func, typename ... Params>
41bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
42 DataType dataType,
43 Float32Func floatFuncPtr,
44 Uint8Func uint8FuncPtr,
45 Params&&... params)
46{
47 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
48 dataType,
49 &FalseFunc<Params...>,
50 floatFuncPtr,
51 uint8FuncPtr,
52 std::forward<Params>(params)...);
53}
54
55} // anonymous namespace
56
arovir011c7c81b2018-10-08 11:34:28 +010057bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
58 const TensorInfo& output,
59 const ActivationDescriptor& descriptor,
60 Optional<std::string&> reasonIfUnsupported) const
61{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010062 ignore_unused(output);
63 ignore_unused(descriptor);
64 return IsSupportedForDataTypeRef(reasonIfUnsupported,
65 input.GetDataType(),
66 &TrueFunc<>,
67 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010068}
69
70bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
71 const TensorInfo& input1,
72 const TensorInfo& output,
73 Optional<std::string&> reasonIfUnsupported) const
74{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010075 ignore_unused(input1);
76 ignore_unused(output);
77 return IsSupportedForDataTypeRef(reasonIfUnsupported,
78 input0.GetDataType(),
79 &TrueFunc<>,
80 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010081}
82
83bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
84 const TensorInfo& output,
85 const TensorInfo& mean,
86 const TensorInfo& var,
87 const TensorInfo& beta,
88 const TensorInfo& gamma,
89 const BatchNormalizationDescriptor& descriptor,
90 Optional<std::string&> reasonIfUnsupported) const
91{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010092 ignore_unused(output);
93 ignore_unused(mean);
94 ignore_unused(var);
95 ignore_unused(beta);
96 ignore_unused(gamma);
97 ignore_unused(descriptor);
98 return IsSupportedForDataTypeRef(reasonIfUnsupported,
99 input.GetDataType(),
100 &TrueFunc<>,
101 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100102}
103
104bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
105 Optional<std::string&> reasonIfUnsupported) const
106{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100107 return IsSupportedForDataTypeRef(reasonIfUnsupported,
108 output.GetDataType(),
109 &TrueFunc<>,
110 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100111}
112
113bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
114 const TensorInfo& output,
115 Optional<std::string&> reasonIfUnsupported) const
116{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100117 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
118 input.GetDataType(),
119 &TrueFunc<>,
120 &FalseInputFuncF32<>,
121 &FalseFuncU8<>) &&
122 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
123 output.GetDataType(),
124 &FalseOutputFuncF16<>,
125 &TrueFunc<>,
126 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100127}
128
129bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
130 const TensorInfo& output,
131 Optional<std::string&> reasonIfUnsupported) const
132{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100133 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
134 input.GetDataType(),
135 &FalseInputFuncF16<>,
136 &TrueFunc<>,
137 &FalseFuncU8<>) &&
138 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
139 output.GetDataType(),
140 &TrueFunc<>,
141 &FalseOutputFuncF32<>,
142 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100143}
144
145bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
146 const TensorInfo& output,
147 const Convolution2dDescriptor& descriptor,
148 const TensorInfo& weights,
149 const Optional<TensorInfo>& biases,
150 Optional<std::string&> reasonIfUnsupported) const
151{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100152 ignore_unused(output);
153 ignore_unused(descriptor);
154 ignore_unused(weights);
155 ignore_unused(biases);
156 return IsSupportedForDataTypeRef(reasonIfUnsupported,
157 input.GetDataType(),
158 &TrueFunc<>,
159 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100160}
161
162bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const DepthwiseConvolution2dDescriptor& descriptor,
165 const TensorInfo& weights,
166 const Optional<TensorInfo>& biases,
167 Optional<std::string&> reasonIfUnsupported) const
168{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100169 ignore_unused(output);
170 ignore_unused(descriptor);
171 ignore_unused(weights);
172 ignore_unused(biases);
173 return IsSupportedForDataTypeRef(reasonIfUnsupported,
174 input.GetDataType(),
175 &TrueFunc<>,
176 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100177}
178
179bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
180 const TensorInfo& input1,
181 const TensorInfo& output,
182 Optional<std::string&> reasonIfUnsupported) const
183{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100184 ignore_unused(input1);
185 ignore_unused(output);
186 return IsSupportedForDataTypeRef(reasonIfUnsupported,
187 input0.GetDataType(),
188 &TrueFunc<>,
189 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100190}
191
192bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
193 const FakeQuantizationDescriptor& descriptor,
194 Optional<std::string&> reasonIfUnsupported) const
195{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100196 ignore_unused(descriptor);
197 return IsSupportedForDataTypeRef(reasonIfUnsupported,
198 input.GetDataType(),
199 &TrueFunc<>,
200 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100201}
202
203bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
204 const TensorInfo& output,
205 Optional<std::string&> reasonIfUnsupported) const
206{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100207 ignore_unused(output);
208 return IsSupportedForDataTypeRef(reasonIfUnsupported,
209 input.GetDataType(),
210 &TrueFunc<>,
211 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100212}
213
214bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
215 const TensorInfo& output,
216 const TensorInfo& weights,
217 const TensorInfo& biases,
218 const FullyConnectedDescriptor& descriptor,
219 Optional<std::string&> reasonIfUnsupported) const
220{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100221 ignore_unused(output);
222 ignore_unused(weights);
223 ignore_unused(biases);
224 ignore_unused(descriptor);
225 return IsSupportedForDataTypeRef(reasonIfUnsupported,
226 input.GetDataType(),
227 &TrueFunc<>,
228 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100229}
230
231bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
232 Optional<std::string&> reasonIfUnsupported) const
233{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100234 return IsSupportedForDataTypeRef(reasonIfUnsupported,
235 input.GetDataType(),
236 &TrueFunc<>,
237 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100238}
239
240bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
241 const TensorInfo& output,
242 const L2NormalizationDescriptor& descriptor,
243 Optional<std::string&> reasonIfUnsupported) const
244{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100245 ignore_unused(output);
246 ignore_unused(descriptor);
247 return IsSupportedForDataTypeRef(reasonIfUnsupported,
248 input.GetDataType(),
249 &TrueFunc<>,
250 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100251}
252
253bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
254 const TensorInfo& outputStateIn,
255 const TensorInfo& cellStateIn,
256 const TensorInfo& scratchBuffer,
257 const TensorInfo& outputStateOut,
258 const TensorInfo& cellStateOut,
259 const TensorInfo& output,
260 const LstmDescriptor& descriptor,
261 const TensorInfo& inputToForgetWeights,
262 const TensorInfo& inputToCellWeights,
263 const TensorInfo& inputToOutputWeights,
264 const TensorInfo& recurrentToForgetWeights,
265 const TensorInfo& recurrentToCellWeights,
266 const TensorInfo& recurrentToOutputWeights,
267 const TensorInfo& forgetGateBias,
268 const TensorInfo& cellBias,
269 const TensorInfo& outputGateBias,
270 const TensorInfo* inputToInputWeights,
271 const TensorInfo* recurrentToInputWeights,
272 const TensorInfo* cellToInputWeights,
273 const TensorInfo* inputGateBias,
274 const TensorInfo* projectionWeights,
275 const TensorInfo* projectionBias,
276 const TensorInfo* cellToForgetWeights,
277 const TensorInfo* cellToOutputWeights,
278 Optional<std::string&> reasonIfUnsupported) const
279{
telsoa01c577f2c2018-08-31 09:22:23 +0100280 ignore_unused(input);
281 ignore_unused(outputStateIn);
282 ignore_unused(cellStateIn);
283 ignore_unused(scratchBuffer);
284 ignore_unused(outputStateOut);
285 ignore_unused(cellStateOut);
286 ignore_unused(output);
287 ignore_unused(descriptor);
288 ignore_unused(inputToForgetWeights);
289 ignore_unused(inputToCellWeights);
290 ignore_unused(inputToOutputWeights);
291 ignore_unused(recurrentToForgetWeights);
292 ignore_unused(recurrentToCellWeights);
293 ignore_unused(recurrentToOutputWeights);
294 ignore_unused(forgetGateBias);
295 ignore_unused(cellBias);
296 ignore_unused(outputGateBias);
297 ignore_unused(inputToInputWeights);
298 ignore_unused(recurrentToInputWeights);
299 ignore_unused(cellToInputWeights);
300 ignore_unused(inputGateBias);
301 ignore_unused(projectionWeights);
302 ignore_unused(projectionBias);
303 ignore_unused(cellToForgetWeights);
304 ignore_unused(cellToOutputWeights);
arovir01085f0a42018-10-08 14:48:19 +0100305 ignore_unused(reasonIfUnsupported);
telsoa01c577f2c2018-08-31 09:22:23 +0100306 return false;
307}
308
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100309bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
310 const TensorInfo& output,
311 const MeanDescriptor& descriptor,
312 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100313{
narpra011e4c31d2018-09-28 11:07:51 +0100314 ignore_unused(output);
315 ignore_unused(descriptor);
316 return IsSupportedForDataTypeRef(reasonIfUnsupported,
317 input.GetDataType(),
318 &TrueFunc<>,
319 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100320}
321
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100322bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
323 const OriginsDescriptor& descriptor,
324 Optional<std::string&> reasonIfUnsupported) const
325{
326 ignore_unused(descriptor);
327 return IsSupportedForDataTypeRef(reasonIfUnsupported,
328 inputs[0]->GetDataType(),
329 &TrueFunc<>,
330 &TrueFunc<>);
331}
332
333bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
334 const TensorInfo& input1,
335 const TensorInfo& output,
336 Optional<std::string&> reasonIfUnsupported) const
337{
338 ignore_unused(input1);
339 ignore_unused(output);
340 return IsSupportedForDataTypeRef(reasonIfUnsupported,
341 input0.GetDataType(),
342 &TrueFunc<>,
343 &TrueFunc<>);
344}
345
346bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
347 const TensorInfo& output,
348 const NormalizationDescriptor& descriptor,
349 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100350{
351 ignore_unused(output);
352 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100353 return IsSupportedForDataTypeRef(reasonIfUnsupported,
354 input.GetDataType(),
355 &TrueFunc<>,
356 &FalseFuncU8<>);
357}
358
359bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
360 Optional<std::string&> reasonIfUnsupported) const
361{
362 return IsSupportedForDataTypeRef(reasonIfUnsupported,
363 output.GetDataType(),
364 &TrueFunc<>,
365 &TrueFunc<>);
366}
367
368bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
369 const TensorInfo& output,
370 const PadDescriptor& descriptor,
371 Optional<std::string&> reasonIfUnsupported) const
372{
373 ignore_unused(input);
374 ignore_unused(output);
375 ignore_unused(descriptor);
376 ignore_unused(reasonIfUnsupported);
Nina Drozd661dfa72018-10-02 11:14:17 +0100377 return false;
378}
379
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100380bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
381 const TensorInfo& output,
382 const PermuteDescriptor& descriptor,
383 Optional<std::string&> reasonIfUnsupported) const
384{
385 ignore_unused(output);
386 ignore_unused(descriptor);
387 return IsSupportedForDataTypeRef(reasonIfUnsupported,
388 input.GetDataType(),
389 &TrueFunc<>,
390 &TrueFunc<>);
391}
392
393bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
394 const TensorInfo& output,
395 const Pooling2dDescriptor& descriptor,
396 Optional<std::string&> reasonIfUnsupported) const
397{
398 ignore_unused(output);
399 ignore_unused(descriptor);
400 return IsSupportedForDataTypeRef(reasonIfUnsupported,
401 input.GetDataType(),
402 &TrueFunc<>,
403 &TrueFunc<>);
404}
405
406bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
407 Optional<std::string&> reasonIfUnsupported) const
408{
409 return IsSupportedForDataTypeRef(reasonIfUnsupported,
410 input.GetDataType(),
411 &TrueFunc<>,
412 &TrueFunc<>);
413}
414
415bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
416 Optional<std::string&> reasonIfUnsupported) const
417{
418 return IsSupportedForDataTypeRef(reasonIfUnsupported,
419 input.GetDataType(),
420 &TrueFunc<>,
421 &TrueFunc<>);
422}
423
424bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
425 const TensorInfo& output,
426 const SoftmaxDescriptor& descriptor,
427 Optional<std::string&> reasonIfUnsupported) const
428{
429 ignore_unused(output);
430 ignore_unused(descriptor);
431 return IsSupportedForDataTypeRef(reasonIfUnsupported,
432 input.GetDataType(),
433 &TrueFunc<>,
434 &TrueFunc<>);
435}
436
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000437bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
438 const TensorInfo& output,
439 const SpaceToBatchNdDescriptor& descriptor,
440 Optional<std::string&> reasonIfUnsupported) const
441{
442 ignore_unused(output);
443 ignore_unused(descriptor);
444 return IsSupportedForDataTypeRef(reasonIfUnsupported,
445 input.GetDataType(),
446 &TrueFunc<>,
447 &TrueFunc<>);
448}
449
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100450bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
451 const ViewsDescriptor& descriptor,
452 Optional<std::string&> reasonIfUnsupported) const
453{
454 ignore_unused(descriptor);
455 return IsSupportedForDataTypeRef(reasonIfUnsupported,
456 input.GetDataType(),
457 &TrueFunc<>,
458 &TrueFunc<>);
459}
460
461bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
462 const TensorInfo& input1,
463 const TensorInfo& output,
464 Optional<std::string&> reasonIfUnsupported) const
465{
466 ignore_unused(input1);
467 ignore_unused(output);
468 return IsSupportedForDataTypeRef(reasonIfUnsupported,
469 input0.GetDataType(),
470 &TrueFunc<>,
471 &TrueFunc<>);
472}
473
arovir011c7c81b2018-10-08 11:34:28 +0100474} // namespace armnn