blob: d56cdebeda0bd140ebb8490c7e5b0175459c704c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "LayerSupportCommon.hpp"
7#include "RefLayerSupport.hpp"
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10#include <armnn/Tensor.hpp>
11
12#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include "InternalTypes.hpp"
14
15using namespace boost;
16
17namespace armnn
18{
19
20template<typename Float32Func, typename Uint8Func, typename ... Params>
21bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
22 DataType dataType,
23 Float32Func floatFuncPtr,
24 Uint8Func uint8FuncPtr,
25 Params&&... params)
26{
27 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
28 dataType,
telsoa01c577f2c2018-08-31 09:22:23 +010029 &FalseFunc<Params...>,
telsoa014fcda012018-03-09 14:13:49 +000030 floatFuncPtr,
31 uint8FuncPtr,
32 std::forward<Params>(params)...);
33}
34
35bool IsActivationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010036 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000037 const ActivationDescriptor& descriptor,
38 std::string* reasonIfUnsupported)
39{
telsoa01c577f2c2018-08-31 09:22:23 +010040 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +000041 ignore_unused(descriptor);
42 return IsSupportedForDataTypeRef(reasonIfUnsupported,
43 input.GetDataType(),
44 &TrueFunc<>,
45 &TrueFunc<>);
46}
47
48bool IsAdditionSupportedRef(const TensorInfo& input0,
49 const TensorInfo& input1,
50 const TensorInfo& output,
51 std::string* reasonIfUnsupported)
52{
53 ignore_unused(input1);
54 ignore_unused(output);
55 return IsSupportedForDataTypeRef(reasonIfUnsupported,
56 input0.GetDataType(),
57 &TrueFunc<>,
58 &TrueFunc<>);
59}
60
61bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010062 const TensorInfo& output,
63 const TensorInfo& mean,
64 const TensorInfo& var,
65 const TensorInfo& beta,
66 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +000067 const BatchNormalizationDescriptor& descriptor,
68 std::string* reasonIfUnsupported)
69{
70 ignore_unused(descriptor);
71 return IsSupportedForDataTypeRef(reasonIfUnsupported,
72 input.GetDataType(),
73 &TrueFunc<>,
74 &TrueFunc<>);
75}
76
77bool IsConstantSupportedRef(const TensorInfo& output,
78 std::string* reasonIfUnsupported)
79{
80 return IsSupportedForDataTypeRef(reasonIfUnsupported,
81 output.GetDataType(),
82 &TrueFunc<>,
83 &TrueFunc<>);
84}
85
86bool IsConvolution2dSupportedRef(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +010087 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000088 const Convolution2dDescriptor& descriptor,
89 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +010090 const boost::optional<TensorInfo>& biases,
telsoa014fcda012018-03-09 14:13:49 +000091 std::string* reasonIfUnsupported)
92{
93 ignore_unused(descriptor);
surmeh013537c2c2018-05-18 16:31:43 +010094 ignore_unused(output);
95 ignore_unused(weights);
96 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +000097 return IsSupportedForDataTypeRef(reasonIfUnsupported,
98 input.GetDataType(),
99 &TrueFunc<>,
100 &TrueFunc<>);
101}
102
103bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100104 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000105 const DepthwiseConvolution2dDescriptor& descriptor,
106 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +0100107 const boost::optional<TensorInfo>& biases,
telsoa014fcda012018-03-09 14:13:49 +0000108 std::string* reasonIfUnsupported)
109{
telsoa01c577f2c2018-08-31 09:22:23 +0100110 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000111 ignore_unused(descriptor);
112 ignore_unused(weights);
telsoa01c577f2c2018-08-31 09:22:23 +0100113 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000114 return IsSupportedForDataTypeRef(reasonIfUnsupported,
115 input.GetDataType(),
116 &TrueFunc<>,
117 &TrueFunc<>);
118}
119
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100120bool IsDivisionSupportedRef(const TensorInfo& input0,
121 const TensorInfo& input1,
122 const TensorInfo& output,
123 std::string* reasonIfUnsupported)
124{
125 ignore_unused(input1);
126 ignore_unused(output);
127 return IsSupportedForDataTypeRef(reasonIfUnsupported,
128 input0.GetDataType(),
129 &TrueFunc<>,
130 &TrueFunc<>);
131}
132
David Beckc2044fe2018-09-05 15:00:38 +0100133bool IsSubtractionSupportedRef(const TensorInfo& input0,
134 const TensorInfo& input1,
135 const TensorInfo& output,
136 std::string* reasonIfUnsupported)
137{
David Beckf195f032018-09-06 16:46:34 +0100138 ignore_unused(input1);
139 ignore_unused(output);
140 return IsSupportedForDataTypeRef(reasonIfUnsupported,
141 input0.GetDataType(),
142 &TrueFunc<>,
143 &TrueFunc<>);
David Beckc2044fe2018-09-05 15:00:38 +0100144}
145
telsoa014fcda012018-03-09 14:13:49 +0000146bool IsFullyConnectedSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100147 const TensorInfo& output,
148 const TensorInfo& weights,
149 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000150 const FullyConnectedDescriptor& descriptor,
151 std::string* reasonIfUnsupported)
152{
telsoa01c577f2c2018-08-31 09:22:23 +0100153 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000154 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100155 ignore_unused(weights);
156 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000157 return IsSupportedForDataTypeRef(reasonIfUnsupported,
158 input.GetDataType(),
159 &TrueFunc<>,
160 &TrueFunc<>);
161}
162
163bool IsInputSupportedRef(const TensorInfo& input,
164 std::string* reasonIfUnsupported)
165{
166 return IsSupportedForDataTypeRef(reasonIfUnsupported,
167 input.GetDataType(),
168 &TrueFunc<>,
169 &TrueFunc<>);
170}
171
172bool IsL2NormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100173 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000174 std::string* reasonIfUnsupported)
175{
telsoa01c577f2c2018-08-31 09:22:23 +0100176 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000177 return IsSupportedForDataTypeRef(reasonIfUnsupported,
178 input.GetDataType(),
179 &TrueFunc<>,
180 &FalseFuncU8<>);
181}
182
183bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
184 const OriginsDescriptor& descriptor,
185 std::string* reasonIfUnsupported)
186{
187 ignore_unused(descriptor);
188 return IsSupportedForDataTypeRef(reasonIfUnsupported,
189 inputs[0]->GetDataType(),
190 &TrueFunc<>,
191 &TrueFunc<>);
192}
193
194bool IsMultiplicationSupportedRef(const TensorInfo& input0,
195 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100196 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000197 std::string* reasonIfUnsupported)
198{
199 ignore_unused(input1);
telsoa01c577f2c2018-08-31 09:22:23 +0100200 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000201 return IsSupportedForDataTypeRef(reasonIfUnsupported,
202 input0.GetDataType(),
203 &TrueFunc<>,
204 &TrueFunc<>);
205}
206
207bool IsNormalizationSupportedRef(const TensorInfo& input,
208 const TensorInfo& output,
209 const NormalizationDescriptor& descriptor,
210 std::string* reasonIfUnsupported)
211{
212 ignore_unused(descriptor);
213 return IsSupportedForDataTypeRef(reasonIfUnsupported,
214 input.GetDataType(),
215 &TrueFunc<>,
216 &FalseFuncU8<>);
217}
218
219bool IsOutputSupportedRef(const TensorInfo& output,
220 std::string* reasonIfUnsupported)
221{
222 return IsSupportedForDataTypeRef(reasonIfUnsupported,
223 output.GetDataType(),
224 &TrueFunc<>,
225 &TrueFunc<>);
226}
227
228bool IsPermuteSupportedRef(const TensorInfo& input,
229 const TensorInfo& output,
230 const PermuteDescriptor& descriptor,
231 std::string* reasonIfUnsupported)
232{
233 ignore_unused(descriptor);
234 return IsSupportedForDataTypeRef(reasonIfUnsupported,
235 input.GetDataType(),
236 &TrueFunc<>,
237 &TrueFunc<>);
238}
239
240bool IsPooling2dSupportedRef(const TensorInfo& input,
241 const TensorInfo& output,
242 const Pooling2dDescriptor& descriptor,
243 std::string* reasonIfUnsupported)
244{
245 ignore_unused(descriptor);
246 return IsSupportedForDataTypeRef(reasonIfUnsupported,
247 input.GetDataType(),
248 &TrueFunc<>,
249 &TrueFunc<>);
250}
251
252bool IsResizeBilinearSupportedRef(const TensorInfo& input,
253 std::string* reasonIfUnsupported)
254{
255 return IsSupportedForDataTypeRef(reasonIfUnsupported,
256 input.GetDataType(),
257 &TrueFunc<>,
258 &TrueFunc<>);
259}
260
261bool IsSoftmaxSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100262 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000263 const SoftmaxDescriptor& descriptor,
264 std::string* reasonIfUnsupported)
265{
telsoa01c577f2c2018-08-31 09:22:23 +0100266 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000267 ignore_unused(descriptor);
268 return IsSupportedForDataTypeRef(reasonIfUnsupported,
269 input.GetDataType(),
270 &TrueFunc<>,
271 &TrueFunc<>);
272}
273
274bool IsSplitterSupportedRef(const TensorInfo& input,
275 const ViewsDescriptor& descriptor,
276 std::string* reasonIfUnsupported)
277{
278 ignore_unused(descriptor);
279 return IsSupportedForDataTypeRef(reasonIfUnsupported,
280 input.GetDataType(),
281 &TrueFunc<>,
282 &TrueFunc<>);
283}
284
285bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
286 const FakeQuantizationDescriptor& descriptor,
287 std::string* reasonIfUnsupported)
288{
289 ignore_unused(descriptor);
290 return IsSupportedForDataTypeRef(reasonIfUnsupported,
291 input.GetDataType(),
292 &TrueFunc<>,
293 &FalseFuncU8<>);
294}
295
296bool IsReshapeSupportedRef(const TensorInfo& input,
297 std::string* reasonIfUnsupported)
298{
299 return IsSupportedForDataTypeRef(reasonIfUnsupported,
300 input.GetDataType(),
301 &TrueFunc<>,
302 &TrueFunc<>);
303}
304
305bool IsFloorSupportedRef(const TensorInfo& input,
306 const TensorInfo& output,
307 std::string* reasonIfUnsupported)
308{
309 ignore_unused(output);
310 return IsSupportedForDataTypeRef(reasonIfUnsupported,
311 input.GetDataType(),
312 &TrueFunc<>,
313 &FalseFuncU8<>);
314}
315
telsoa01c577f2c2018-08-31 09:22:23 +0100316bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
317 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
318 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
319 const TensorInfo& output, const LstmDescriptor& descriptor,
320 const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
321 const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
322 const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
323 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
324 const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
325 const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
326 const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
327 const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
328 const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
329{
330 ignore_unused(input);
331 ignore_unused(outputStateIn);
332 ignore_unused(cellStateIn);
333 ignore_unused(scratchBuffer);
334 ignore_unused(outputStateOut);
335 ignore_unused(cellStateOut);
336 ignore_unused(output);
337 ignore_unused(descriptor);
338 ignore_unused(inputToForgetWeights);
339 ignore_unused(inputToCellWeights);
340 ignore_unused(inputToOutputWeights);
341 ignore_unused(recurrentToForgetWeights);
342 ignore_unused(recurrentToCellWeights);
343 ignore_unused(recurrentToOutputWeights);
344 ignore_unused(forgetGateBias);
345 ignore_unused(cellBias);
346 ignore_unused(outputGateBias);
347 ignore_unused(inputToInputWeights);
348 ignore_unused(recurrentToInputWeights);
349 ignore_unused(cellToInputWeights);
350 ignore_unused(inputGateBias);
351 ignore_unused(projectionWeights);
352 ignore_unused(projectionBias);
353 ignore_unused(cellToForgetWeights);
354 ignore_unused(cellToOutputWeights);
355 return false;
356}
357
358bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
359 const TensorInfo& output,
360 std::string* reasonIfUnsupported)
361{
362 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
363 input.GetDataType(),
364 &TrueFunc<>,
365 &FalseInputFuncF32<>,
366 &FalseFuncU8<>) &&
367 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
368 output.GetDataType(),
369 &FalseOutputFuncF16<>,
370 &TrueFunc<>,
371 &FalseFuncU8<>));
372}
373
374bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
375 const TensorInfo& output,
376 std::string* reasonIfUnsupported)
377{
378 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
379 input.GetDataType(),
380 &FalseInputFuncF16<>,
381 &TrueFunc<>,
382 &FalseFuncU8<>) &&
383 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
384 output.GetDataType(),
385 &TrueFunc<>,
386 &FalseOutputFuncF32<>,
387 &FalseFuncU8<>));
388}
389
narpra0132b90462018-09-13 11:07:48 +0100390bool IsMeanSupportedRef(const TensorInfo& input,
391 const TensorInfo& output,
392 const MeanDescriptor& descriptor,
393 std::string* reasonIfUnsupported)
394{
395 return false;
396}
397
telsoa014fcda012018-03-09 14:13:49 +0000398}