blob: 91be98182a24362b0a134c1369b21a199c4de4c3 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
#include <armnn/DescriptorsFwd.hpp>
#include <armnn/ILayerSupport.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>

#include <string>
#include <vector>
11
12namespace armnn
13{
14
arovir014424b0a2018-10-04 10:46:04 +010015class NeonLayerSupport : public ILayerSupport
16{
arovir017ff76c52018-10-09 09:40:58 +010017public:
18 bool IsActivationSupported(const TensorInfo& input,
19 const TensorInfo& output,
20 const ActivationDescriptor& descriptor,
21 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
22
23 bool IsAdditionSupported(const TensorInfo& input0,
24 const TensorInfo& input1,
25 const TensorInfo& output,
26 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
27
28 bool IsBatchNormalizationSupported(const TensorInfo& input,
29 const TensorInfo& output,
30 const TensorInfo& mean,
31 const TensorInfo& var,
32 const TensorInfo& beta,
33 const TensorInfo& gamma,
34 const BatchNormalizationDescriptor& descriptor,
35 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
36
37 bool IsConstantSupported(const TensorInfo& output,
38 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
39
40 bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
41 const TensorInfo& output,
42 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
43
44 bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
45 const TensorInfo& output,
46 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
47
48 bool IsConvolution2dSupported(const TensorInfo& input,
49 const TensorInfo& output,
50 const Convolution2dDescriptor& descriptor,
51 const TensorInfo& weights,
52 const Optional<TensorInfo>& biases,
53 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
54
55 bool IsDepthwiseConvolutionSupported(const TensorInfo& input,
56 const TensorInfo& output,
57 const DepthwiseConvolution2dDescriptor& descriptor,
58 const TensorInfo& weights,
59 const Optional<TensorInfo>& biases,
60 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
61
62 bool IsDivisionSupported(const TensorInfo& input0,
63 const TensorInfo& input1,
64 const TensorInfo& output,
65 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
66
67 bool IsFakeQuantizationSupported(const TensorInfo& input,
68 const FakeQuantizationDescriptor& descriptor,
69 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
70
71 bool IsFloorSupported(const TensorInfo& input,
72 const TensorInfo& output,
73 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
74
75 bool IsFullyConnectedSupported(const TensorInfo& input,
76 const TensorInfo& output,
77 const TensorInfo& weights,
78 const TensorInfo& biases,
79 const FullyConnectedDescriptor& descriptor,
80 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
81
82 bool IsInputSupported(const TensorInfo& input,
83 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
84
85 bool IsL2NormalizationSupported(const TensorInfo& input,
86 const TensorInfo& output,
87 const L2NormalizationDescriptor& descriptor,
88 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
89
90 bool IsLstmSupported(const TensorInfo& input,
91 const TensorInfo& outputStateIn,
92 const TensorInfo& cellStateIn,
93 const TensorInfo& scratchBuffer,
94 const TensorInfo& outputStateOut,
95 const TensorInfo& cellStateOut,
96 const TensorInfo& output,
97 const LstmDescriptor& descriptor,
98 const TensorInfo& inputToForgetWeights,
99 const TensorInfo& inputToCellWeights,
100 const TensorInfo& inputToOutputWeights,
101 const TensorInfo& recurrentToForgetWeights,
102 const TensorInfo& recurrentToCellWeights,
103 const TensorInfo& recurrentToOutputWeights,
104 const TensorInfo& forgetGateBias,
105 const TensorInfo& cellBias,
106 const TensorInfo& outputGateBias,
107 const TensorInfo* inputToInputWeights,
108 const TensorInfo* recurrentToInputWeights,
109 const TensorInfo* cellToInputWeights,
110 const TensorInfo* inputGateBias,
111 const TensorInfo* projectionWeights,
112 const TensorInfo* projectionBias,
113 const TensorInfo* cellToForgetWeights,
114 const TensorInfo* cellToOutputWeights,
115 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
116
117 bool IsMeanSupported(const TensorInfo& input,
118 const TensorInfo& output,
119 const MeanDescriptor& descriptor,
120 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
121
122 bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
123 const OriginsDescriptor& descriptor,
124 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
125
126 bool IsMultiplicationSupported(const TensorInfo& input0,
127 const TensorInfo& input1,
128 const TensorInfo& output,
129 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
130
131 bool IsNormalizationSupported(const TensorInfo& input,
132 const TensorInfo& output,
133 const NormalizationDescriptor& descriptor,
134 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
135
136 bool IsOutputSupported(const TensorInfo& output,
137 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
138
139 bool IsPadSupported(const TensorInfo& input,
140 const TensorInfo& output,
141 const PadDescriptor& descriptor,
142 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
143
144 bool IsPermuteSupported(const TensorInfo& input,
145 const TensorInfo& output,
146 const PermuteDescriptor& descriptor,
147 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
148
149 bool IsPooling2dSupported(const TensorInfo& input,
150 const TensorInfo& output,
151 const Pooling2dDescriptor& descriptor,
152 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
153
154 bool IsReshapeSupported(const TensorInfo& input,
155 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
156
157 bool IsResizeBilinearSupported(const TensorInfo& input,
158 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
159
160 bool IsSoftmaxSupported(const TensorInfo& input,
161 const TensorInfo& output,
162 const SoftmaxDescriptor& descriptor,
163 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
164
165 bool IsSplitterSupported(const TensorInfo& input,
166 const ViewsDescriptor& descriptor,
167 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
168
169 bool IsSubtractionSupported(const TensorInfo& input0,
170 const TensorInfo& input1,
171 const TensorInfo& output,
172 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
arovir014424b0a2018-10-04 10:46:04 +0100173};
174
telsoa014fcda012018-03-09 14:13:49 +0000175bool IsNeonDirectConvolutionPreferred(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc);
176
arovir01085f0a42018-10-08 14:48:19 +0100177bool IsNeonNormalizationDescParamsSupported(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000178 const NormalizationDescriptor& parameters);
179
180bool IsActivationSupportedNeon(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100181 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000182 const ActivationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100183 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000184
arovir01085f0a42018-10-08 14:48:19 +0100185bool IsNeonDepthwiseConvolution2dDescParamsSupported(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000186 const DepthwiseConvolution2dDescriptor& parameters,
187 const TensorInfo& weights);
188
189bool IsAdditionSupportedNeon(const TensorInfo& input0,
190 const TensorInfo& input1,
191 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100192 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000193
194bool IsBatchNormalizationSupportedNeon(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100195 const TensorInfo& output,
196 const TensorInfo& mean,
197 const TensorInfo& var,
198 const TensorInfo& beta,
199 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +0000200 const BatchNormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100201 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000202
203bool IsConstantSupportedNeon(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100204 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000205
206bool IsConvolution2dSupportedNeon(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +0100207 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000208 const Convolution2dDescriptor& descriptor,
209 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100210 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100211 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000212
telsoa01c577f2c2018-08-31 09:22:23 +0100213
telsoa014fcda012018-03-09 14:13:49 +0000214bool IsDepthwiseConvolutionSupportedNeon(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100215 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000216 const DepthwiseConvolution2dDescriptor& descriptor,
217 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100218 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100219 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000220
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100221bool IsDivisionSupportedNeon(const TensorInfo& input0,
222 const TensorInfo& input1,
223 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100224 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100225
David Beckc2044fe2018-09-05 15:00:38 +0100226bool IsSubtractionSupportedNeon(const TensorInfo& input0,
227 const TensorInfo& input1,
228 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100229 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
David Beckc2044fe2018-09-05 15:00:38 +0100230
telsoa014fcda012018-03-09 14:13:49 +0000231bool IsFullyConnectedSupportedNeon(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100232 const TensorInfo& output,
233 const TensorInfo& weights,
234 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000235 const FullyConnectedDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100236 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000237
238bool IsInputSupportedNeon(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100239 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000240
241bool IsL2NormalizationSupportedNeon(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100242 const TensorInfo& output,
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100243 const L2NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100244 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000245
246bool IsMergerSupportedNeon(const std::vector<const TensorInfo*> inputs,
247 const OriginsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100248 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000249
250bool IsMultiplicationSupportedNeon(const TensorInfo& input0,
251 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100252 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100253 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000254
255bool IsNormalizationSupportedNeon(const TensorInfo& input,
256 const TensorInfo& output,
257 const NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100258 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000259
260bool IsOutputSupportedNeon(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100261 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000262
263bool IsPermuteSupportedNeon(const TensorInfo& input,
264 const TensorInfo& output,
265 const PermuteDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100266 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000267
268bool IsPooling2dSupportedNeon(const TensorInfo& input,
269 const TensorInfo& output,
270 const Pooling2dDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100271 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000272
273bool IsResizeBilinearSupportedNeon(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100274 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000275
276bool IsSoftmaxSupportedNeon(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100277 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000278 const SoftmaxDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100279 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000280
281bool IsSplitterSupportedNeon(const TensorInfo& input,
282 const ViewsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100283 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000284
285bool IsFakeQuantizationSupportedNeon(const TensorInfo& input,
286 const FakeQuantizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100287 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000288
289bool IsReshapeSupportedNeon(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100290 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000291
292bool IsFloorSupportedNeon(const TensorInfo& input,
293 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100294 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa014fcda012018-03-09 14:13:49 +0000295
arovir01085f0a42018-10-08 14:48:19 +0100296bool IsLstmSupportedNeon(const TensorInfo& input,
297 const TensorInfo& outputStateIn,
298 const TensorInfo& cellStateIn,
299 const TensorInfo& scratchBuffer,
300 const TensorInfo& outputStateOut,
301 const TensorInfo& cellStateOut,
302 const TensorInfo& output,
303 const LstmDescriptor& descriptor,
304 const TensorInfo& inputToForgetWeights,
305 const TensorInfo& inputToCellWeights,
306 const TensorInfo& inputToOutputWeights,
307 const TensorInfo& recurrentToForgetWeights,
308 const TensorInfo& recurrentToCellWeights,
309 const TensorInfo& recurrentToOutputWeights,
310 const TensorInfo& forgetGateBias,
311 const TensorInfo& cellBias,
312 const TensorInfo& outputGateBias,
313 const TensorInfo* inputToInputWeights,
314 const TensorInfo* recurrentToInputWeights,
315 const TensorInfo* cellToInputWeights,
316 const TensorInfo* inputGateBias,
317 const TensorInfo* projectionWeights,
318 const TensorInfo* projectionBias,
319 const TensorInfo* cellToForgetWeights,
320 const TensorInfo* cellToOutputWeights,
321 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +0100322
323bool IsConvertFp16ToFp32SupportedNeon(const TensorInfo& input,
324 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100325 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +0100326
327bool IsConvertFp32ToFp16SupportedNeon(const TensorInfo& input,
328 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100329 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +0100330
narpra0132b90462018-09-13 11:07:48 +0100331bool IsMeanSupportedNeon(const TensorInfo& input,
332 const TensorInfo& output,
333 const MeanDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100334 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
narpra0132b90462018-09-13 11:07:48 +0100335
Nina Drozd661dfa72018-10-02 11:14:17 +0100336bool IsPadSupportedNeon(const TensorInfo& input,
337 const TensorInfo& output,
338 const PadDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100339 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
Nina Drozd661dfa72018-10-02 11:14:17 +0100340
arovir017ff76c52018-10-09 09:40:58 +0100341} // namespace armnn