blob: 909df7544508b894f0127c9173deb359bb986857 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01008#include <armnn/InternalTypes.hpp>
9#include <armnn/LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010#include <armnn/Types.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
12#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
14using namespace boost;
15
16namespace armnn
17{
18
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010019namespace
20{
21
22template<typename Float32Func, typename Uint8Func, typename ... Params>
23bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
24 DataType dataType,
25 Float32Func floatFuncPtr,
26 Uint8Func uint8FuncPtr,
27 Params&&... params)
28{
29 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
30 dataType,
31 &FalseFunc<Params...>,
32 floatFuncPtr,
33 uint8FuncPtr,
34 std::forward<Params>(params)...);
35}
36
37} // anonymous namespace
38
arovir011c7c81b2018-10-08 11:34:28 +010039bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
40 const TensorInfo& output,
41 const ActivationDescriptor& descriptor,
42 Optional<std::string&> reasonIfUnsupported) const
43{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010044 ignore_unused(output);
45 ignore_unused(descriptor);
46 return IsSupportedForDataTypeRef(reasonIfUnsupported,
47 input.GetDataType(),
48 &TrueFunc<>,
49 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010050}
51
52bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
53 const TensorInfo& input1,
54 const TensorInfo& output,
55 Optional<std::string&> reasonIfUnsupported) const
56{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010057 ignore_unused(input1);
58 ignore_unused(output);
59 return IsSupportedForDataTypeRef(reasonIfUnsupported,
60 input0.GetDataType(),
61 &TrueFunc<>,
62 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010063}
64
65bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
66 const TensorInfo& output,
67 const TensorInfo& mean,
68 const TensorInfo& var,
69 const TensorInfo& beta,
70 const TensorInfo& gamma,
71 const BatchNormalizationDescriptor& descriptor,
72 Optional<std::string&> reasonIfUnsupported) const
73{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010074 ignore_unused(output);
75 ignore_unused(mean);
76 ignore_unused(var);
77 ignore_unused(beta);
78 ignore_unused(gamma);
79 ignore_unused(descriptor);
80 return IsSupportedForDataTypeRef(reasonIfUnsupported,
81 input.GetDataType(),
82 &TrueFunc<>,
83 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010084}
85
86bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
87 Optional<std::string&> reasonIfUnsupported) const
88{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010089 return IsSupportedForDataTypeRef(reasonIfUnsupported,
90 output.GetDataType(),
91 &TrueFunc<>,
92 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +010093}
94
95bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
96 const TensorInfo& output,
97 Optional<std::string&> reasonIfUnsupported) const
98{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010099 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
100 input.GetDataType(),
101 &TrueFunc<>,
102 &FalseInputFuncF32<>,
103 &FalseFuncU8<>) &&
104 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
105 output.GetDataType(),
106 &FalseOutputFuncF16<>,
107 &TrueFunc<>,
108 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100109}
110
111bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
112 const TensorInfo& output,
113 Optional<std::string&> reasonIfUnsupported) const
114{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100115 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
116 input.GetDataType(),
117 &FalseInputFuncF16<>,
118 &TrueFunc<>,
119 &FalseFuncU8<>) &&
120 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
121 output.GetDataType(),
122 &TrueFunc<>,
123 &FalseOutputFuncF32<>,
124 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100125}
126
127bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
128 const TensorInfo& output,
129 const Convolution2dDescriptor& descriptor,
130 const TensorInfo& weights,
131 const Optional<TensorInfo>& biases,
132 Optional<std::string&> reasonIfUnsupported) const
133{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100134 ignore_unused(output);
135 ignore_unused(descriptor);
136 ignore_unused(weights);
137 ignore_unused(biases);
138 return IsSupportedForDataTypeRef(reasonIfUnsupported,
139 input.GetDataType(),
140 &TrueFunc<>,
141 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100142}
143
144bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
145 const TensorInfo& output,
146 const DepthwiseConvolution2dDescriptor& descriptor,
147 const TensorInfo& weights,
148 const Optional<TensorInfo>& biases,
149 Optional<std::string&> reasonIfUnsupported) const
150{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100151 ignore_unused(output);
152 ignore_unused(descriptor);
153 ignore_unused(weights);
154 ignore_unused(biases);
155 return IsSupportedForDataTypeRef(reasonIfUnsupported,
156 input.GetDataType(),
157 &TrueFunc<>,
158 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100159}
160
161bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
162 const TensorInfo& input1,
163 const TensorInfo& output,
164 Optional<std::string&> reasonIfUnsupported) const
165{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100166 ignore_unused(input1);
167 ignore_unused(output);
168 return IsSupportedForDataTypeRef(reasonIfUnsupported,
169 input0.GetDataType(),
170 &TrueFunc<>,
171 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100172}
173
174bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
175 const FakeQuantizationDescriptor& descriptor,
176 Optional<std::string&> reasonIfUnsupported) const
177{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100178 ignore_unused(descriptor);
179 return IsSupportedForDataTypeRef(reasonIfUnsupported,
180 input.GetDataType(),
181 &TrueFunc<>,
182 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100183}
184
185bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
186 const TensorInfo& output,
187 Optional<std::string&> reasonIfUnsupported) const
188{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100189 ignore_unused(output);
190 return IsSupportedForDataTypeRef(reasonIfUnsupported,
191 input.GetDataType(),
192 &TrueFunc<>,
193 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100194}
195
196bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
197 const TensorInfo& output,
198 const TensorInfo& weights,
199 const TensorInfo& biases,
200 const FullyConnectedDescriptor& descriptor,
201 Optional<std::string&> reasonIfUnsupported) const
202{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100203 ignore_unused(output);
204 ignore_unused(weights);
205 ignore_unused(biases);
206 ignore_unused(descriptor);
207 return IsSupportedForDataTypeRef(reasonIfUnsupported,
208 input.GetDataType(),
209 &TrueFunc<>,
210 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100211}
212
213bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
214 Optional<std::string&> reasonIfUnsupported) const
215{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100216 return IsSupportedForDataTypeRef(reasonIfUnsupported,
217 input.GetDataType(),
218 &TrueFunc<>,
219 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100220}
221
222bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
223 const TensorInfo& output,
224 const L2NormalizationDescriptor& descriptor,
225 Optional<std::string&> reasonIfUnsupported) const
226{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100227 ignore_unused(output);
228 ignore_unused(descriptor);
229 return IsSupportedForDataTypeRef(reasonIfUnsupported,
230 input.GetDataType(),
231 &TrueFunc<>,
232 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100233}
234
235bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
236 const TensorInfo& outputStateIn,
237 const TensorInfo& cellStateIn,
238 const TensorInfo& scratchBuffer,
239 const TensorInfo& outputStateOut,
240 const TensorInfo& cellStateOut,
241 const TensorInfo& output,
242 const LstmDescriptor& descriptor,
243 const TensorInfo& inputToForgetWeights,
244 const TensorInfo& inputToCellWeights,
245 const TensorInfo& inputToOutputWeights,
246 const TensorInfo& recurrentToForgetWeights,
247 const TensorInfo& recurrentToCellWeights,
248 const TensorInfo& recurrentToOutputWeights,
249 const TensorInfo& forgetGateBias,
250 const TensorInfo& cellBias,
251 const TensorInfo& outputGateBias,
252 const TensorInfo* inputToInputWeights,
253 const TensorInfo* recurrentToInputWeights,
254 const TensorInfo* cellToInputWeights,
255 const TensorInfo* inputGateBias,
256 const TensorInfo* projectionWeights,
257 const TensorInfo* projectionBias,
258 const TensorInfo* cellToForgetWeights,
259 const TensorInfo* cellToOutputWeights,
260 Optional<std::string&> reasonIfUnsupported) const
261{
telsoa01c577f2c2018-08-31 09:22:23 +0100262 ignore_unused(input);
263 ignore_unused(outputStateIn);
264 ignore_unused(cellStateIn);
265 ignore_unused(scratchBuffer);
266 ignore_unused(outputStateOut);
267 ignore_unused(cellStateOut);
268 ignore_unused(output);
269 ignore_unused(descriptor);
270 ignore_unused(inputToForgetWeights);
271 ignore_unused(inputToCellWeights);
272 ignore_unused(inputToOutputWeights);
273 ignore_unused(recurrentToForgetWeights);
274 ignore_unused(recurrentToCellWeights);
275 ignore_unused(recurrentToOutputWeights);
276 ignore_unused(forgetGateBias);
277 ignore_unused(cellBias);
278 ignore_unused(outputGateBias);
279 ignore_unused(inputToInputWeights);
280 ignore_unused(recurrentToInputWeights);
281 ignore_unused(cellToInputWeights);
282 ignore_unused(inputGateBias);
283 ignore_unused(projectionWeights);
284 ignore_unused(projectionBias);
285 ignore_unused(cellToForgetWeights);
286 ignore_unused(cellToOutputWeights);
arovir01085f0a42018-10-08 14:48:19 +0100287 ignore_unused(reasonIfUnsupported);
telsoa01c577f2c2018-08-31 09:22:23 +0100288 return false;
289}
290
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100291bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
292 const TensorInfo& output,
293 const MeanDescriptor& descriptor,
294 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100295{
narpra011e4c31d2018-09-28 11:07:51 +0100296 ignore_unused(output);
297 ignore_unused(descriptor);
298 return IsSupportedForDataTypeRef(reasonIfUnsupported,
299 input.GetDataType(),
300 &TrueFunc<>,
301 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100302}
303
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100304bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
305 const OriginsDescriptor& descriptor,
306 Optional<std::string&> reasonIfUnsupported) const
307{
308 ignore_unused(descriptor);
309 return IsSupportedForDataTypeRef(reasonIfUnsupported,
310 inputs[0]->GetDataType(),
311 &TrueFunc<>,
312 &TrueFunc<>);
313}
314
315bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
316 const TensorInfo& input1,
317 const TensorInfo& output,
318 Optional<std::string&> reasonIfUnsupported) const
319{
320 ignore_unused(input1);
321 ignore_unused(output);
322 return IsSupportedForDataTypeRef(reasonIfUnsupported,
323 input0.GetDataType(),
324 &TrueFunc<>,
325 &TrueFunc<>);
326}
327
328bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
329 const TensorInfo& output,
330 const NormalizationDescriptor& descriptor,
331 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100332{
333 ignore_unused(output);
334 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100335 return IsSupportedForDataTypeRef(reasonIfUnsupported,
336 input.GetDataType(),
337 &TrueFunc<>,
338 &FalseFuncU8<>);
339}
340
341bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
342 Optional<std::string&> reasonIfUnsupported) const
343{
344 return IsSupportedForDataTypeRef(reasonIfUnsupported,
345 output.GetDataType(),
346 &TrueFunc<>,
347 &TrueFunc<>);
348}
349
350bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
351 const TensorInfo& output,
352 const PadDescriptor& descriptor,
353 Optional<std::string&> reasonIfUnsupported) const
354{
355 ignore_unused(input);
356 ignore_unused(output);
357 ignore_unused(descriptor);
358 ignore_unused(reasonIfUnsupported);
Nina Drozd661dfa72018-10-02 11:14:17 +0100359 return false;
360}
361
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100362bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
363 const TensorInfo& output,
364 const PermuteDescriptor& descriptor,
365 Optional<std::string&> reasonIfUnsupported) const
366{
367 ignore_unused(output);
368 ignore_unused(descriptor);
369 return IsSupportedForDataTypeRef(reasonIfUnsupported,
370 input.GetDataType(),
371 &TrueFunc<>,
372 &TrueFunc<>);
373}
374
375bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
376 const TensorInfo& output,
377 const Pooling2dDescriptor& descriptor,
378 Optional<std::string&> reasonIfUnsupported) const
379{
380 ignore_unused(output);
381 ignore_unused(descriptor);
382 return IsSupportedForDataTypeRef(reasonIfUnsupported,
383 input.GetDataType(),
384 &TrueFunc<>,
385 &TrueFunc<>);
386}
387
388bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
389 Optional<std::string&> reasonIfUnsupported) const
390{
391 return IsSupportedForDataTypeRef(reasonIfUnsupported,
392 input.GetDataType(),
393 &TrueFunc<>,
394 &TrueFunc<>);
395}
396
397bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
398 Optional<std::string&> reasonIfUnsupported) const
399{
400 return IsSupportedForDataTypeRef(reasonIfUnsupported,
401 input.GetDataType(),
402 &TrueFunc<>,
403 &TrueFunc<>);
404}
405
406bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
407 const TensorInfo& output,
408 const SoftmaxDescriptor& descriptor,
409 Optional<std::string&> reasonIfUnsupported) const
410{
411 ignore_unused(output);
412 ignore_unused(descriptor);
413 return IsSupportedForDataTypeRef(reasonIfUnsupported,
414 input.GetDataType(),
415 &TrueFunc<>,
416 &TrueFunc<>);
417}
418
419bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
420 const ViewsDescriptor& descriptor,
421 Optional<std::string&> reasonIfUnsupported) const
422{
423 ignore_unused(descriptor);
424 return IsSupportedForDataTypeRef(reasonIfUnsupported,
425 input.GetDataType(),
426 &TrueFunc<>,
427 &TrueFunc<>);
428}
429
430bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
431 const TensorInfo& input1,
432 const TensorInfo& output,
433 Optional<std::string&> reasonIfUnsupported) const
434{
435 ignore_unused(input1);
436 ignore_unused(output);
437 return IsSupportedForDataTypeRef(reasonIfUnsupported,
438 input0.GetDataType(),
439 &TrueFunc<>,
440 &TrueFunc<>);
441}
442
arovir011c7c81b2018-10-08 11:34:28 +0100443} // namespace armnn