blob: ee91e73df2d05373520fc8fc7a46a543e1150b7c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "LayerSupportCommon.hpp"
7#include "RefLayerSupport.hpp"
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10#include <armnn/Tensor.hpp>
11
12#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include "InternalTypes.hpp"
14
15using namespace boost;
16
17namespace armnn
18{
19
20template<typename Float32Func, typename Uint8Func, typename ... Params>
21bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
22 DataType dataType,
23 Float32Func floatFuncPtr,
24 Uint8Func uint8FuncPtr,
25 Params&&... params)
26{
27 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
28 dataType,
telsoa01c577f2c2018-08-31 09:22:23 +010029 &FalseFunc<Params...>,
telsoa014fcda012018-03-09 14:13:49 +000030 floatFuncPtr,
31 uint8FuncPtr,
32 std::forward<Params>(params)...);
33}
34
35bool IsActivationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010036 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000037 const ActivationDescriptor& descriptor,
38 std::string* reasonIfUnsupported)
39{
telsoa01c577f2c2018-08-31 09:22:23 +010040 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +000041 ignore_unused(descriptor);
42 return IsSupportedForDataTypeRef(reasonIfUnsupported,
43 input.GetDataType(),
44 &TrueFunc<>,
45 &TrueFunc<>);
46}
47
48bool IsAdditionSupportedRef(const TensorInfo& input0,
49 const TensorInfo& input1,
50 const TensorInfo& output,
51 std::string* reasonIfUnsupported)
52{
53 ignore_unused(input1);
54 ignore_unused(output);
55 return IsSupportedForDataTypeRef(reasonIfUnsupported,
56 input0.GetDataType(),
57 &TrueFunc<>,
58 &TrueFunc<>);
59}
60
61bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +010062 const TensorInfo& output,
63 const TensorInfo& mean,
64 const TensorInfo& var,
65 const TensorInfo& beta,
66 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +000067 const BatchNormalizationDescriptor& descriptor,
68 std::string* reasonIfUnsupported)
69{
70 ignore_unused(descriptor);
71 return IsSupportedForDataTypeRef(reasonIfUnsupported,
72 input.GetDataType(),
73 &TrueFunc<>,
74 &TrueFunc<>);
75}
76
77bool IsConstantSupportedRef(const TensorInfo& output,
78 std::string* reasonIfUnsupported)
79{
80 return IsSupportedForDataTypeRef(reasonIfUnsupported,
81 output.GetDataType(),
82 &TrueFunc<>,
83 &TrueFunc<>);
84}
85
86bool IsConvolution2dSupportedRef(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +010087 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +000088 const Convolution2dDescriptor& descriptor,
89 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +010090 const boost::optional<TensorInfo>& biases,
telsoa014fcda012018-03-09 14:13:49 +000091 std::string* reasonIfUnsupported)
92{
93 ignore_unused(descriptor);
surmeh013537c2c2018-05-18 16:31:43 +010094 ignore_unused(output);
95 ignore_unused(weights);
96 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +000097 return IsSupportedForDataTypeRef(reasonIfUnsupported,
98 input.GetDataType(),
99 &TrueFunc<>,
100 &TrueFunc<>);
101}
102
103bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100104 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000105 const DepthwiseConvolution2dDescriptor& descriptor,
106 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +0100107 const boost::optional<TensorInfo>& biases,
telsoa014fcda012018-03-09 14:13:49 +0000108 std::string* reasonIfUnsupported)
109{
telsoa01c577f2c2018-08-31 09:22:23 +0100110 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000111 ignore_unused(descriptor);
112 ignore_unused(weights);
telsoa01c577f2c2018-08-31 09:22:23 +0100113 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000114 return IsSupportedForDataTypeRef(reasonIfUnsupported,
115 input.GetDataType(),
116 &TrueFunc<>,
117 &TrueFunc<>);
118}
119
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100120bool IsDivisionSupportedRef(const TensorInfo& input0,
121 const TensorInfo& input1,
122 const TensorInfo& output,
123 std::string* reasonIfUnsupported)
124{
125 ignore_unused(input1);
126 ignore_unused(output);
127 return IsSupportedForDataTypeRef(reasonIfUnsupported,
128 input0.GetDataType(),
129 &TrueFunc<>,
130 &TrueFunc<>);
131}
132
telsoa014fcda012018-03-09 14:13:49 +0000133bool IsFullyConnectedSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100134 const TensorInfo& output,
135 const TensorInfo& weights,
136 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000137 const FullyConnectedDescriptor& descriptor,
138 std::string* reasonIfUnsupported)
139{
telsoa01c577f2c2018-08-31 09:22:23 +0100140 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000141 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100142 ignore_unused(weights);
143 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000144 return IsSupportedForDataTypeRef(reasonIfUnsupported,
145 input.GetDataType(),
146 &TrueFunc<>,
147 &TrueFunc<>);
148}
149
150bool IsInputSupportedRef(const TensorInfo& input,
151 std::string* reasonIfUnsupported)
152{
153 return IsSupportedForDataTypeRef(reasonIfUnsupported,
154 input.GetDataType(),
155 &TrueFunc<>,
156 &TrueFunc<>);
157}
158
159bool IsL2NormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100160 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000161 std::string* reasonIfUnsupported)
162{
telsoa01c577f2c2018-08-31 09:22:23 +0100163 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000164 return IsSupportedForDataTypeRef(reasonIfUnsupported,
165 input.GetDataType(),
166 &TrueFunc<>,
167 &FalseFuncU8<>);
168}
169
170bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
171 const OriginsDescriptor& descriptor,
172 std::string* reasonIfUnsupported)
173{
174 ignore_unused(descriptor);
175 return IsSupportedForDataTypeRef(reasonIfUnsupported,
176 inputs[0]->GetDataType(),
177 &TrueFunc<>,
178 &TrueFunc<>);
179}
180
181bool IsMultiplicationSupportedRef(const TensorInfo& input0,
182 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100183 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000184 std::string* reasonIfUnsupported)
185{
186 ignore_unused(input1);
telsoa01c577f2c2018-08-31 09:22:23 +0100187 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000188 return IsSupportedForDataTypeRef(reasonIfUnsupported,
189 input0.GetDataType(),
190 &TrueFunc<>,
191 &TrueFunc<>);
192}
193
194bool IsNormalizationSupportedRef(const TensorInfo& input,
195 const TensorInfo& output,
196 const NormalizationDescriptor& descriptor,
197 std::string* reasonIfUnsupported)
198{
199 ignore_unused(descriptor);
200 return IsSupportedForDataTypeRef(reasonIfUnsupported,
201 input.GetDataType(),
202 &TrueFunc<>,
203 &FalseFuncU8<>);
204}
205
206bool IsOutputSupportedRef(const TensorInfo& output,
207 std::string* reasonIfUnsupported)
208{
209 return IsSupportedForDataTypeRef(reasonIfUnsupported,
210 output.GetDataType(),
211 &TrueFunc<>,
212 &TrueFunc<>);
213}
214
215bool IsPermuteSupportedRef(const TensorInfo& input,
216 const TensorInfo& output,
217 const PermuteDescriptor& descriptor,
218 std::string* reasonIfUnsupported)
219{
220 ignore_unused(descriptor);
221 return IsSupportedForDataTypeRef(reasonIfUnsupported,
222 input.GetDataType(),
223 &TrueFunc<>,
224 &TrueFunc<>);
225}
226
227bool IsPooling2dSupportedRef(const TensorInfo& input,
228 const TensorInfo& output,
229 const Pooling2dDescriptor& descriptor,
230 std::string* reasonIfUnsupported)
231{
232 ignore_unused(descriptor);
233 return IsSupportedForDataTypeRef(reasonIfUnsupported,
234 input.GetDataType(),
235 &TrueFunc<>,
236 &TrueFunc<>);
237}
238
239bool IsResizeBilinearSupportedRef(const TensorInfo& input,
240 std::string* reasonIfUnsupported)
241{
242 return IsSupportedForDataTypeRef(reasonIfUnsupported,
243 input.GetDataType(),
244 &TrueFunc<>,
245 &TrueFunc<>);
246}
247
248bool IsSoftmaxSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100249 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000250 const SoftmaxDescriptor& descriptor,
251 std::string* reasonIfUnsupported)
252{
telsoa01c577f2c2018-08-31 09:22:23 +0100253 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000254 ignore_unused(descriptor);
255 return IsSupportedForDataTypeRef(reasonIfUnsupported,
256 input.GetDataType(),
257 &TrueFunc<>,
258 &TrueFunc<>);
259}
260
261bool IsSplitterSupportedRef(const TensorInfo& input,
262 const ViewsDescriptor& descriptor,
263 std::string* reasonIfUnsupported)
264{
265 ignore_unused(descriptor);
266 return IsSupportedForDataTypeRef(reasonIfUnsupported,
267 input.GetDataType(),
268 &TrueFunc<>,
269 &TrueFunc<>);
270}
271
272bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
273 const FakeQuantizationDescriptor& descriptor,
274 std::string* reasonIfUnsupported)
275{
276 ignore_unused(descriptor);
277 return IsSupportedForDataTypeRef(reasonIfUnsupported,
278 input.GetDataType(),
279 &TrueFunc<>,
280 &FalseFuncU8<>);
281}
282
283bool IsReshapeSupportedRef(const TensorInfo& input,
284 std::string* reasonIfUnsupported)
285{
286 return IsSupportedForDataTypeRef(reasonIfUnsupported,
287 input.GetDataType(),
288 &TrueFunc<>,
289 &TrueFunc<>);
290}
291
292bool IsFloorSupportedRef(const TensorInfo& input,
293 const TensorInfo& output,
294 std::string* reasonIfUnsupported)
295{
296 ignore_unused(output);
297 return IsSupportedForDataTypeRef(reasonIfUnsupported,
298 input.GetDataType(),
299 &TrueFunc<>,
300 &FalseFuncU8<>);
301}
302
telsoa01c577f2c2018-08-31 09:22:23 +0100303bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
304 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
305 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
306 const TensorInfo& output, const LstmDescriptor& descriptor,
307 const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
308 const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
309 const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
310 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
311 const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
312 const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
313 const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
314 const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
315 const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
316{
317 ignore_unused(input);
318 ignore_unused(outputStateIn);
319 ignore_unused(cellStateIn);
320 ignore_unused(scratchBuffer);
321 ignore_unused(outputStateOut);
322 ignore_unused(cellStateOut);
323 ignore_unused(output);
324 ignore_unused(descriptor);
325 ignore_unused(inputToForgetWeights);
326 ignore_unused(inputToCellWeights);
327 ignore_unused(inputToOutputWeights);
328 ignore_unused(recurrentToForgetWeights);
329 ignore_unused(recurrentToCellWeights);
330 ignore_unused(recurrentToOutputWeights);
331 ignore_unused(forgetGateBias);
332 ignore_unused(cellBias);
333 ignore_unused(outputGateBias);
334 ignore_unused(inputToInputWeights);
335 ignore_unused(recurrentToInputWeights);
336 ignore_unused(cellToInputWeights);
337 ignore_unused(inputGateBias);
338 ignore_unused(projectionWeights);
339 ignore_unused(projectionBias);
340 ignore_unused(cellToForgetWeights);
341 ignore_unused(cellToOutputWeights);
342 return false;
343}
344
345bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
346 const TensorInfo& output,
347 std::string* reasonIfUnsupported)
348{
349 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
350 input.GetDataType(),
351 &TrueFunc<>,
352 &FalseInputFuncF32<>,
353 &FalseFuncU8<>) &&
354 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
355 output.GetDataType(),
356 &FalseOutputFuncF16<>,
357 &TrueFunc<>,
358 &FalseFuncU8<>));
359}
360
361bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
362 const TensorInfo& output,
363 std::string* reasonIfUnsupported)
364{
365 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
366 input.GetDataType(),
367 &FalseInputFuncF16<>,
368 &TrueFunc<>,
369 &FalseFuncU8<>) &&
370 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
371 output.GetDataType(),
372 &TrueFunc<>,
373 &FalseOutputFuncF32<>,
374 &FalseFuncU8<>));
375}
376
telsoa014fcda012018-03-09 14:13:49 +0000377}