blob: 1ca3d5b6d642201068b989a3daedb7b42f50ffc9 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "LayerSupportCommon.hpp"
7#include "RefLayerSupport.hpp"
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10#include <armnn/Tensor.hpp>
11
12#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include "InternalTypes.hpp"
14
15using namespace boost;
16
17namespace armnn
18{
19
20template<typename Float32Func, typename Uint8Func, typename ... Params>
21bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
22 DataType dataType,
23 Float32Func floatFuncPtr,
24 Uint8Func uint8FuncPtr,
25 Params&&... params)
26{
27 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
28 dataType,
telsoa01c577f2c2018-08-31 09:22:23 +010029 &FalseFunc<Params...>,
telsoa014fcda012018-03-09 14:13:49 +000030 floatFuncPtr,
31 uint8FuncPtr,
32 std::forward<Params>(params)...);
33}
34
35bool IsActivationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010036 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000037 const ActivationDescriptor& descriptor,
38 std::string* reasonIfUnsupported)
39{
telsoa01c577f2c2018-08-31 09:22:23 +010040 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +000041 ignore_unused(descriptor);
42 return IsSupportedForDataTypeRef(reasonIfUnsupported,
43 input.GetDataType(),
44 &TrueFunc<>,
45 &TrueFunc<>);
46}
47
48bool IsAdditionSupportedRef(const TensorInfo& input0,
49 const TensorInfo& input1,
50 const TensorInfo& output,
51 std::string* reasonIfUnsupported)
52{
53 ignore_unused(input1);
54 ignore_unused(output);
55 return IsSupportedForDataTypeRef(reasonIfUnsupported,
56 input0.GetDataType(),
57 &TrueFunc<>,
58 &TrueFunc<>);
59}
60
61bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010062 const TensorInfo& output,
63 const TensorInfo& mean,
64 const TensorInfo& var,
65 const TensorInfo& beta,
66 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +000067 const BatchNormalizationDescriptor& descriptor,
68 std::string* reasonIfUnsupported)
69{
70 ignore_unused(descriptor);
71 return IsSupportedForDataTypeRef(reasonIfUnsupported,
72 input.GetDataType(),
73 &TrueFunc<>,
74 &TrueFunc<>);
75}
76
77bool IsConstantSupportedRef(const TensorInfo& output,
78 std::string* reasonIfUnsupported)
79{
80 return IsSupportedForDataTypeRef(reasonIfUnsupported,
81 output.GetDataType(),
82 &TrueFunc<>,
83 &TrueFunc<>);
84}
85
86bool IsConvolution2dSupportedRef(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +010087 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000088 const Convolution2dDescriptor& descriptor,
89 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +010090 const boost::optional<TensorInfo>& biases,
telsoa014fcda012018-03-09 14:13:49 +000091 std::string* reasonIfUnsupported)
92{
93 ignore_unused(descriptor);
surmeh013537c2c2018-05-18 16:31:43 +010094 ignore_unused(output);
95 ignore_unused(weights);
96 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +000097 return IsSupportedForDataTypeRef(reasonIfUnsupported,
98 input.GetDataType(),
99 &TrueFunc<>,
100 &TrueFunc<>);
101}
102
103bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100104 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000105 const DepthwiseConvolution2dDescriptor& descriptor,
106 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +0100107 const boost::optional<TensorInfo>& biases,
telsoa014fcda012018-03-09 14:13:49 +0000108 std::string* reasonIfUnsupported)
109{
telsoa01c577f2c2018-08-31 09:22:23 +0100110 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000111 ignore_unused(descriptor);
112 ignore_unused(weights);
telsoa01c577f2c2018-08-31 09:22:23 +0100113 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000114 return IsSupportedForDataTypeRef(reasonIfUnsupported,
115 input.GetDataType(),
116 &TrueFunc<>,
117 &TrueFunc<>);
118}
119
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100120bool IsDivisionSupportedRef(const TensorInfo& input0,
121 const TensorInfo& input1,
122 const TensorInfo& output,
123 std::string* reasonIfUnsupported)
124{
125 ignore_unused(input1);
126 ignore_unused(output);
127 return IsSupportedForDataTypeRef(reasonIfUnsupported,
128 input0.GetDataType(),
129 &TrueFunc<>,
130 &TrueFunc<>);
131}
132
David Beckc2044fe2018-09-05 15:00:38 +0100133bool IsSubtractionSupportedRef(const TensorInfo& input0,
134 const TensorInfo& input1,
135 const TensorInfo& output,
136 std::string* reasonIfUnsupported)
137{
David Beckf195f032018-09-06 16:46:34 +0100138 ignore_unused(input1);
139 ignore_unused(output);
140 return IsSupportedForDataTypeRef(reasonIfUnsupported,
141 input0.GetDataType(),
142 &TrueFunc<>,
143 &TrueFunc<>);
David Beckc2044fe2018-09-05 15:00:38 +0100144}
145
telsoa014fcda012018-03-09 14:13:49 +0000146bool IsFullyConnectedSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100147 const TensorInfo& output,
148 const TensorInfo& weights,
149 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000150 const FullyConnectedDescriptor& descriptor,
151 std::string* reasonIfUnsupported)
152{
telsoa01c577f2c2018-08-31 09:22:23 +0100153 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000154 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100155 ignore_unused(weights);
156 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000157 return IsSupportedForDataTypeRef(reasonIfUnsupported,
158 input.GetDataType(),
159 &TrueFunc<>,
160 &TrueFunc<>);
161}
162
163bool IsInputSupportedRef(const TensorInfo& input,
164 std::string* reasonIfUnsupported)
165{
166 return IsSupportedForDataTypeRef(reasonIfUnsupported,
167 input.GetDataType(),
168 &TrueFunc<>,
169 &TrueFunc<>);
170}
171
172bool IsL2NormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100173 const TensorInfo& output,
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100174 const L2NormalizationDescriptor& descriptor,
telsoa014fcda012018-03-09 14:13:49 +0000175 std::string* reasonIfUnsupported)
176{
telsoa01c577f2c2018-08-31 09:22:23 +0100177 ignore_unused(output);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100178 ignore_unused(descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000179 return IsSupportedForDataTypeRef(reasonIfUnsupported,
180 input.GetDataType(),
181 &TrueFunc<>,
182 &FalseFuncU8<>);
183}
184
185bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
186 const OriginsDescriptor& descriptor,
187 std::string* reasonIfUnsupported)
188{
189 ignore_unused(descriptor);
190 return IsSupportedForDataTypeRef(reasonIfUnsupported,
191 inputs[0]->GetDataType(),
192 &TrueFunc<>,
193 &TrueFunc<>);
194}
195
196bool IsMultiplicationSupportedRef(const TensorInfo& input0,
197 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100198 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000199 std::string* reasonIfUnsupported)
200{
201 ignore_unused(input1);
telsoa01c577f2c2018-08-31 09:22:23 +0100202 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000203 return IsSupportedForDataTypeRef(reasonIfUnsupported,
204 input0.GetDataType(),
205 &TrueFunc<>,
206 &TrueFunc<>);
207}
208
209bool IsNormalizationSupportedRef(const TensorInfo& input,
210 const TensorInfo& output,
211 const NormalizationDescriptor& descriptor,
212 std::string* reasonIfUnsupported)
213{
214 ignore_unused(descriptor);
215 return IsSupportedForDataTypeRef(reasonIfUnsupported,
216 input.GetDataType(),
217 &TrueFunc<>,
218 &FalseFuncU8<>);
219}
220
221bool IsOutputSupportedRef(const TensorInfo& output,
222 std::string* reasonIfUnsupported)
223{
224 return IsSupportedForDataTypeRef(reasonIfUnsupported,
225 output.GetDataType(),
226 &TrueFunc<>,
227 &TrueFunc<>);
228}
229
230bool IsPermuteSupportedRef(const TensorInfo& input,
231 const TensorInfo& output,
232 const PermuteDescriptor& descriptor,
233 std::string* reasonIfUnsupported)
234{
235 ignore_unused(descriptor);
236 return IsSupportedForDataTypeRef(reasonIfUnsupported,
237 input.GetDataType(),
238 &TrueFunc<>,
239 &TrueFunc<>);
240}
241
242bool IsPooling2dSupportedRef(const TensorInfo& input,
243 const TensorInfo& output,
244 const Pooling2dDescriptor& descriptor,
245 std::string* reasonIfUnsupported)
246{
247 ignore_unused(descriptor);
248 return IsSupportedForDataTypeRef(reasonIfUnsupported,
249 input.GetDataType(),
250 &TrueFunc<>,
251 &TrueFunc<>);
252}
253
254bool IsResizeBilinearSupportedRef(const TensorInfo& input,
255 std::string* reasonIfUnsupported)
256{
257 return IsSupportedForDataTypeRef(reasonIfUnsupported,
258 input.GetDataType(),
259 &TrueFunc<>,
260 &TrueFunc<>);
261}
262
263bool IsSoftmaxSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100264 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000265 const SoftmaxDescriptor& descriptor,
266 std::string* reasonIfUnsupported)
267{
telsoa01c577f2c2018-08-31 09:22:23 +0100268 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000269 ignore_unused(descriptor);
270 return IsSupportedForDataTypeRef(reasonIfUnsupported,
271 input.GetDataType(),
272 &TrueFunc<>,
273 &TrueFunc<>);
274}
275
276bool IsSplitterSupportedRef(const TensorInfo& input,
277 const ViewsDescriptor& descriptor,
278 std::string* reasonIfUnsupported)
279{
280 ignore_unused(descriptor);
281 return IsSupportedForDataTypeRef(reasonIfUnsupported,
282 input.GetDataType(),
283 &TrueFunc<>,
284 &TrueFunc<>);
285}
286
287bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
288 const FakeQuantizationDescriptor& descriptor,
289 std::string* reasonIfUnsupported)
290{
291 ignore_unused(descriptor);
292 return IsSupportedForDataTypeRef(reasonIfUnsupported,
293 input.GetDataType(),
294 &TrueFunc<>,
295 &FalseFuncU8<>);
296}
297
298bool IsReshapeSupportedRef(const TensorInfo& input,
299 std::string* reasonIfUnsupported)
300{
301 return IsSupportedForDataTypeRef(reasonIfUnsupported,
302 input.GetDataType(),
303 &TrueFunc<>,
304 &TrueFunc<>);
305}
306
307bool IsFloorSupportedRef(const TensorInfo& input,
308 const TensorInfo& output,
309 std::string* reasonIfUnsupported)
310{
311 ignore_unused(output);
312 return IsSupportedForDataTypeRef(reasonIfUnsupported,
313 input.GetDataType(),
314 &TrueFunc<>,
315 &FalseFuncU8<>);
316}
317
telsoa01c577f2c2018-08-31 09:22:23 +0100318bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
319 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
320 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
321 const TensorInfo& output, const LstmDescriptor& descriptor,
322 const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
323 const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
324 const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
325 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
326 const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
327 const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
328 const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
329 const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
330 const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
331{
332 ignore_unused(input);
333 ignore_unused(outputStateIn);
334 ignore_unused(cellStateIn);
335 ignore_unused(scratchBuffer);
336 ignore_unused(outputStateOut);
337 ignore_unused(cellStateOut);
338 ignore_unused(output);
339 ignore_unused(descriptor);
340 ignore_unused(inputToForgetWeights);
341 ignore_unused(inputToCellWeights);
342 ignore_unused(inputToOutputWeights);
343 ignore_unused(recurrentToForgetWeights);
344 ignore_unused(recurrentToCellWeights);
345 ignore_unused(recurrentToOutputWeights);
346 ignore_unused(forgetGateBias);
347 ignore_unused(cellBias);
348 ignore_unused(outputGateBias);
349 ignore_unused(inputToInputWeights);
350 ignore_unused(recurrentToInputWeights);
351 ignore_unused(cellToInputWeights);
352 ignore_unused(inputGateBias);
353 ignore_unused(projectionWeights);
354 ignore_unused(projectionBias);
355 ignore_unused(cellToForgetWeights);
356 ignore_unused(cellToOutputWeights);
357 return false;
358}
359
360bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
361 const TensorInfo& output,
362 std::string* reasonIfUnsupported)
363{
364 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
365 input.GetDataType(),
366 &TrueFunc<>,
367 &FalseInputFuncF32<>,
368 &FalseFuncU8<>) &&
369 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
370 output.GetDataType(),
371 &FalseOutputFuncF16<>,
372 &TrueFunc<>,
373 &FalseFuncU8<>));
374}
375
376bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
377 const TensorInfo& output,
378 std::string* reasonIfUnsupported)
379{
380 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
381 input.GetDataType(),
382 &FalseInputFuncF16<>,
383 &TrueFunc<>,
384 &FalseFuncU8<>) &&
385 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
386 output.GetDataType(),
387 &TrueFunc<>,
388 &FalseOutputFuncF32<>,
389 &FalseFuncU8<>));
390}
391
narpra0132b90462018-09-13 11:07:48 +0100392bool IsMeanSupportedRef(const TensorInfo& input,
393 const TensorInfo& output,
394 const MeanDescriptor& descriptor,
395 std::string* reasonIfUnsupported)
396{
narpra011e4c31d2018-09-28 11:07:51 +0100397 ignore_unused(output);
398 ignore_unused(descriptor);
399 return IsSupportedForDataTypeRef(reasonIfUnsupported,
400 input.GetDataType(),
401 &TrueFunc<>,
402 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100403}
404
Nina Drozd661dfa72018-10-02 11:14:17 +0100405bool IsPadSupportedRef(const TensorInfo& input,
406 const TensorInfo& output,
407 const PadDescriptor& descriptor,
408 std::string* reasonIfUnsupported)
409{
410 ignore_unused(output);
411 ignore_unused(descriptor);
412 return false;
413}
414
telsoa014fcda012018-03-09 14:13:49 +0000415}