blob: ca4fca6f318617738857dda051892b5c882c86db [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include "LayerSupportCommon.hpp"
7#include "RefLayerSupport.hpp"
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10#include <armnn/Tensor.hpp>
11
12#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include "InternalTypes.hpp"
14
15using namespace boost;
16
17namespace armnn
18{
19
20template<typename Float32Func, typename Uint8Func, typename ... Params>
21bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
22 DataType dataType,
23 Float32Func floatFuncPtr,
24 Uint8Func uint8FuncPtr,
25 Params&&... params)
26{
27 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
28 dataType,
telsoa01c577f2c2018-08-31 09:22:23 +010029 &FalseFunc<Params...>,
telsoa014fcda012018-03-09 14:13:49 +000030 floatFuncPtr,
31 uint8FuncPtr,
32 std::forward<Params>(params)...);
33}
34
35bool IsActivationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010036 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000037 const ActivationDescriptor& descriptor,
38 std::string* reasonIfUnsupported)
39{
telsoa01c577f2c2018-08-31 09:22:23 +010040 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +000041 ignore_unused(descriptor);
42 return IsSupportedForDataTypeRef(reasonIfUnsupported,
43 input.GetDataType(),
44 &TrueFunc<>,
45 &TrueFunc<>);
46}
47
48bool IsAdditionSupportedRef(const TensorInfo& input0,
49 const TensorInfo& input1,
50 const TensorInfo& output,
51 std::string* reasonIfUnsupported)
52{
53 ignore_unused(input1);
54 ignore_unused(output);
55 return IsSupportedForDataTypeRef(reasonIfUnsupported,
56 input0.GetDataType(),
57 &TrueFunc<>,
58 &TrueFunc<>);
59}
60
61bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010062 const TensorInfo& output,
63 const TensorInfo& mean,
64 const TensorInfo& var,
65 const TensorInfo& beta,
66 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +000067 const BatchNormalizationDescriptor& descriptor,
68 std::string* reasonIfUnsupported)
69{
70 ignore_unused(descriptor);
71 return IsSupportedForDataTypeRef(reasonIfUnsupported,
72 input.GetDataType(),
73 &TrueFunc<>,
74 &TrueFunc<>);
75}
76
77bool IsConstantSupportedRef(const TensorInfo& output,
78 std::string* reasonIfUnsupported)
79{
80 return IsSupportedForDataTypeRef(reasonIfUnsupported,
81 output.GetDataType(),
82 &TrueFunc<>,
83 &TrueFunc<>);
84}
85
86bool IsConvolution2dSupportedRef(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +010087 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000088 const Convolution2dDescriptor& descriptor,
89 const TensorInfo& weights,
surmeh013537c2c2018-05-18 16:31:43 +010090 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +000091 std::string* reasonIfUnsupported)
92{
93 ignore_unused(descriptor);
surmeh013537c2c2018-05-18 16:31:43 +010094 ignore_unused(output);
95 ignore_unused(weights);
96 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +000097 return IsSupportedForDataTypeRef(reasonIfUnsupported,
98 input.GetDataType(),
99 &TrueFunc<>,
100 &TrueFunc<>);
101}
102
103bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100104 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000105 const DepthwiseConvolution2dDescriptor& descriptor,
106 const TensorInfo& weights,
telsoa01c577f2c2018-08-31 09:22:23 +0100107 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000108 std::string* reasonIfUnsupported)
109{
telsoa01c577f2c2018-08-31 09:22:23 +0100110 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000111 ignore_unused(descriptor);
112 ignore_unused(weights);
telsoa01c577f2c2018-08-31 09:22:23 +0100113 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000114 return IsSupportedForDataTypeRef(reasonIfUnsupported,
115 input.GetDataType(),
116 &TrueFunc<>,
117 &TrueFunc<>);
118}
119
120bool IsFullyConnectedSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100121 const TensorInfo& output,
122 const TensorInfo& weights,
123 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000124 const FullyConnectedDescriptor& descriptor,
125 std::string* reasonIfUnsupported)
126{
telsoa01c577f2c2018-08-31 09:22:23 +0100127 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000128 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100129 ignore_unused(weights);
130 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000131 return IsSupportedForDataTypeRef(reasonIfUnsupported,
132 input.GetDataType(),
133 &TrueFunc<>,
134 &TrueFunc<>);
135}
136
137bool IsInputSupportedRef(const TensorInfo& input,
138 std::string* reasonIfUnsupported)
139{
140 return IsSupportedForDataTypeRef(reasonIfUnsupported,
141 input.GetDataType(),
142 &TrueFunc<>,
143 &TrueFunc<>);
144}
145
146bool IsL2NormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100147 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000148 std::string* reasonIfUnsupported)
149{
telsoa01c577f2c2018-08-31 09:22:23 +0100150 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000151 return IsSupportedForDataTypeRef(reasonIfUnsupported,
152 input.GetDataType(),
153 &TrueFunc<>,
154 &FalseFuncU8<>);
155}
156
157bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
158 const OriginsDescriptor& descriptor,
159 std::string* reasonIfUnsupported)
160{
161 ignore_unused(descriptor);
162 return IsSupportedForDataTypeRef(reasonIfUnsupported,
163 inputs[0]->GetDataType(),
164 &TrueFunc<>,
165 &TrueFunc<>);
166}
167
168bool IsMultiplicationSupportedRef(const TensorInfo& input0,
169 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100170 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000171 std::string* reasonIfUnsupported)
172{
173 ignore_unused(input1);
telsoa01c577f2c2018-08-31 09:22:23 +0100174 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000175 return IsSupportedForDataTypeRef(reasonIfUnsupported,
176 input0.GetDataType(),
177 &TrueFunc<>,
178 &TrueFunc<>);
179}
180
181bool IsNormalizationSupportedRef(const TensorInfo& input,
182 const TensorInfo& output,
183 const NormalizationDescriptor& descriptor,
184 std::string* reasonIfUnsupported)
185{
186 ignore_unused(descriptor);
187 return IsSupportedForDataTypeRef(reasonIfUnsupported,
188 input.GetDataType(),
189 &TrueFunc<>,
190 &FalseFuncU8<>);
191}
192
193bool IsOutputSupportedRef(const TensorInfo& output,
194 std::string* reasonIfUnsupported)
195{
196 return IsSupportedForDataTypeRef(reasonIfUnsupported,
197 output.GetDataType(),
198 &TrueFunc<>,
199 &TrueFunc<>);
200}
201
202bool IsPermuteSupportedRef(const TensorInfo& input,
203 const TensorInfo& output,
204 const PermuteDescriptor& descriptor,
205 std::string* reasonIfUnsupported)
206{
207 ignore_unused(descriptor);
208 return IsSupportedForDataTypeRef(reasonIfUnsupported,
209 input.GetDataType(),
210 &TrueFunc<>,
211 &TrueFunc<>);
212}
213
214bool IsPooling2dSupportedRef(const TensorInfo& input,
215 const TensorInfo& output,
216 const Pooling2dDescriptor& descriptor,
217 std::string* reasonIfUnsupported)
218{
219 ignore_unused(descriptor);
220 return IsSupportedForDataTypeRef(reasonIfUnsupported,
221 input.GetDataType(),
222 &TrueFunc<>,
223 &TrueFunc<>);
224}
225
226bool IsResizeBilinearSupportedRef(const TensorInfo& input,
227 std::string* reasonIfUnsupported)
228{
229 return IsSupportedForDataTypeRef(reasonIfUnsupported,
230 input.GetDataType(),
231 &TrueFunc<>,
232 &TrueFunc<>);
233}
234
235bool IsSoftmaxSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100236 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000237 const SoftmaxDescriptor& descriptor,
238 std::string* reasonIfUnsupported)
239{
telsoa01c577f2c2018-08-31 09:22:23 +0100240 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000241 ignore_unused(descriptor);
242 return IsSupportedForDataTypeRef(reasonIfUnsupported,
243 input.GetDataType(),
244 &TrueFunc<>,
245 &TrueFunc<>);
246}
247
248bool IsSplitterSupportedRef(const TensorInfo& input,
249 const ViewsDescriptor& descriptor,
250 std::string* reasonIfUnsupported)
251{
252 ignore_unused(descriptor);
253 return IsSupportedForDataTypeRef(reasonIfUnsupported,
254 input.GetDataType(),
255 &TrueFunc<>,
256 &TrueFunc<>);
257}
258
259bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
260 const FakeQuantizationDescriptor& descriptor,
261 std::string* reasonIfUnsupported)
262{
263 ignore_unused(descriptor);
264 return IsSupportedForDataTypeRef(reasonIfUnsupported,
265 input.GetDataType(),
266 &TrueFunc<>,
267 &FalseFuncU8<>);
268}
269
270bool IsReshapeSupportedRef(const TensorInfo& input,
271 std::string* reasonIfUnsupported)
272{
273 return IsSupportedForDataTypeRef(reasonIfUnsupported,
274 input.GetDataType(),
275 &TrueFunc<>,
276 &TrueFunc<>);
277}
278
279bool IsFloorSupportedRef(const TensorInfo& input,
280 const TensorInfo& output,
281 std::string* reasonIfUnsupported)
282{
283 ignore_unused(output);
284 return IsSupportedForDataTypeRef(reasonIfUnsupported,
285 input.GetDataType(),
286 &TrueFunc<>,
287 &FalseFuncU8<>);
288}
289
telsoa01c577f2c2018-08-31 09:22:23 +0100290bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
291 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
292 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
293 const TensorInfo& output, const LstmDescriptor& descriptor,
294 const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
295 const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
296 const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
297 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
298 const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
299 const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
300 const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
301 const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
302 const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
303{
304 ignore_unused(input);
305 ignore_unused(outputStateIn);
306 ignore_unused(cellStateIn);
307 ignore_unused(scratchBuffer);
308 ignore_unused(outputStateOut);
309 ignore_unused(cellStateOut);
310 ignore_unused(output);
311 ignore_unused(descriptor);
312 ignore_unused(inputToForgetWeights);
313 ignore_unused(inputToCellWeights);
314 ignore_unused(inputToOutputWeights);
315 ignore_unused(recurrentToForgetWeights);
316 ignore_unused(recurrentToCellWeights);
317 ignore_unused(recurrentToOutputWeights);
318 ignore_unused(forgetGateBias);
319 ignore_unused(cellBias);
320 ignore_unused(outputGateBias);
321 ignore_unused(inputToInputWeights);
322 ignore_unused(recurrentToInputWeights);
323 ignore_unused(cellToInputWeights);
324 ignore_unused(inputGateBias);
325 ignore_unused(projectionWeights);
326 ignore_unused(projectionBias);
327 ignore_unused(cellToForgetWeights);
328 ignore_unused(cellToOutputWeights);
329 return false;
330}
331
332bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
333 const TensorInfo& output,
334 std::string* reasonIfUnsupported)
335{
336 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
337 input.GetDataType(),
338 &TrueFunc<>,
339 &FalseInputFuncF32<>,
340 &FalseFuncU8<>) &&
341 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
342 output.GetDataType(),
343 &FalseOutputFuncF16<>,
344 &TrueFunc<>,
345 &FalseFuncU8<>));
346}
347
348bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
349 const TensorInfo& output,
350 std::string* reasonIfUnsupported)
351{
352 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
353 input.GetDataType(),
354 &FalseInputFuncF16<>,
355 &TrueFunc<>,
356 &FalseFuncU8<>) &&
357 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
358 output.GetDataType(),
359 &TrueFunc<>,
360 &FalseOutputFuncF32<>,
361 &FalseFuncU8<>));
362}
363
telsoa014fcda012018-03-09 14:13:49 +0000364}