blob: 964c18e8ea1c942e1ac40a4e0b62fc4167159fbb [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#include "LayerSupportCommon.hpp"
7#include "RefLayerSupport.hpp"
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10#include <armnn/Tensor.hpp>
11
12#include <boost/core/ignore_unused.hpp>
13
14#include "InternalTypes.hpp"
15
16using namespace boost;
17
18namespace armnn
19{
20
21template<typename Float32Func, typename Uint8Func, typename ... Params>
22bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
23 DataType dataType,
24 Float32Func floatFuncPtr,
25 Uint8Func uint8FuncPtr,
26 Params&&... params)
27{
28 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
29 dataType,
30 floatFuncPtr,
31 uint8FuncPtr,
32 std::forward<Params>(params)...);
33}
34
35bool IsActivationSupportedRef(const TensorInfo& input,
36 const ActivationDescriptor& descriptor,
37 std::string* reasonIfUnsupported)
38{
39 ignore_unused(descriptor);
40 return IsSupportedForDataTypeRef(reasonIfUnsupported,
41 input.GetDataType(),
42 &TrueFunc<>,
43 &TrueFunc<>);
44}
45
46bool IsAdditionSupportedRef(const TensorInfo& input0,
47 const TensorInfo& input1,
48 const TensorInfo& output,
49 std::string* reasonIfUnsupported)
50{
51 ignore_unused(input1);
52 ignore_unused(output);
53 return IsSupportedForDataTypeRef(reasonIfUnsupported,
54 input0.GetDataType(),
55 &TrueFunc<>,
56 &TrueFunc<>);
57}
58
59bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
60 const BatchNormalizationDescriptor& descriptor,
61 std::string* reasonIfUnsupported)
62{
63 ignore_unused(descriptor);
64 return IsSupportedForDataTypeRef(reasonIfUnsupported,
65 input.GetDataType(),
66 &TrueFunc<>,
67 &TrueFunc<>);
68}
69
70bool IsConstantSupportedRef(const TensorInfo& output,
71 std::string* reasonIfUnsupported)
72{
73 return IsSupportedForDataTypeRef(reasonIfUnsupported,
74 output.GetDataType(),
75 &TrueFunc<>,
76 &TrueFunc<>);
77}
78
79bool IsConvolution2dSupportedRef(const TensorInfo& input,
80 const Convolution2dDescriptor& descriptor,
81 const TensorInfo& weights,
82 std::string* reasonIfUnsupported)
83{
84 ignore_unused(descriptor);
85 return IsSupportedForDataTypeRef(reasonIfUnsupported,
86 input.GetDataType(),
87 &TrueFunc<>,
88 &TrueFunc<>);
89}
90
91bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
92 const DepthwiseConvolution2dDescriptor& descriptor,
93 const TensorInfo& weights,
94 std::string* reasonIfUnsupported)
95{
96 ignore_unused(descriptor);
97 ignore_unused(weights);
98 return IsSupportedForDataTypeRef(reasonIfUnsupported,
99 input.GetDataType(),
100 &TrueFunc<>,
101 &TrueFunc<>);
102}
103
104bool IsFullyConnectedSupportedRef(const TensorInfo& input,
105 const FullyConnectedDescriptor& descriptor,
106 std::string* reasonIfUnsupported)
107{
108 ignore_unused(descriptor);
109 return IsSupportedForDataTypeRef(reasonIfUnsupported,
110 input.GetDataType(),
111 &TrueFunc<>,
112 &TrueFunc<>);
113}
114
115bool IsInputSupportedRef(const TensorInfo& input,
116 std::string* reasonIfUnsupported)
117{
118 return IsSupportedForDataTypeRef(reasonIfUnsupported,
119 input.GetDataType(),
120 &TrueFunc<>,
121 &TrueFunc<>);
122}
123
124bool IsL2NormalizationSupportedRef(const TensorInfo& input,
125 std::string* reasonIfUnsupported)
126{
127 return IsSupportedForDataTypeRef(reasonIfUnsupported,
128 input.GetDataType(),
129 &TrueFunc<>,
130 &FalseFuncU8<>);
131}
132
133bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
134 const OriginsDescriptor& descriptor,
135 std::string* reasonIfUnsupported)
136{
137 ignore_unused(descriptor);
138 return IsSupportedForDataTypeRef(reasonIfUnsupported,
139 inputs[0]->GetDataType(),
140 &TrueFunc<>,
141 &TrueFunc<>);
142}
143
144bool IsMultiplicationSupportedRef(const TensorInfo& input0,
145 const TensorInfo& input1,
146 std::string* reasonIfUnsupported)
147{
148 ignore_unused(input1);
149 return IsSupportedForDataTypeRef(reasonIfUnsupported,
150 input0.GetDataType(),
151 &TrueFunc<>,
152 &TrueFunc<>);
153}
154
155bool IsNormalizationSupportedRef(const TensorInfo& input,
156 const TensorInfo& output,
157 const NormalizationDescriptor& descriptor,
158 std::string* reasonIfUnsupported)
159{
160 ignore_unused(descriptor);
161 return IsSupportedForDataTypeRef(reasonIfUnsupported,
162 input.GetDataType(),
163 &TrueFunc<>,
164 &FalseFuncU8<>);
165}
166
167bool IsOutputSupportedRef(const TensorInfo& output,
168 std::string* reasonIfUnsupported)
169{
170 return IsSupportedForDataTypeRef(reasonIfUnsupported,
171 output.GetDataType(),
172 &TrueFunc<>,
173 &TrueFunc<>);
174}
175
176bool IsPermuteSupportedRef(const TensorInfo& input,
177 const TensorInfo& output,
178 const PermuteDescriptor& descriptor,
179 std::string* reasonIfUnsupported)
180{
181 ignore_unused(descriptor);
182 return IsSupportedForDataTypeRef(reasonIfUnsupported,
183 input.GetDataType(),
184 &TrueFunc<>,
185 &TrueFunc<>);
186}
187
188bool IsPooling2dSupportedRef(const TensorInfo& input,
189 const TensorInfo& output,
190 const Pooling2dDescriptor& descriptor,
191 std::string* reasonIfUnsupported)
192{
193 ignore_unused(descriptor);
194 return IsSupportedForDataTypeRef(reasonIfUnsupported,
195 input.GetDataType(),
196 &TrueFunc<>,
197 &TrueFunc<>);
198}
199
200bool IsResizeBilinearSupportedRef(const TensorInfo& input,
201 std::string* reasonIfUnsupported)
202{
203 return IsSupportedForDataTypeRef(reasonIfUnsupported,
204 input.GetDataType(),
205 &TrueFunc<>,
206 &TrueFunc<>);
207}
208
209bool IsSoftmaxSupportedRef(const TensorInfo& input,
210 const SoftmaxDescriptor& descriptor,
211 std::string* reasonIfUnsupported)
212{
213 ignore_unused(descriptor);
214 return IsSupportedForDataTypeRef(reasonIfUnsupported,
215 input.GetDataType(),
216 &TrueFunc<>,
217 &TrueFunc<>);
218}
219
220bool IsSplitterSupportedRef(const TensorInfo& input,
221 const ViewsDescriptor& descriptor,
222 std::string* reasonIfUnsupported)
223{
224 ignore_unused(descriptor);
225 return IsSupportedForDataTypeRef(reasonIfUnsupported,
226 input.GetDataType(),
227 &TrueFunc<>,
228 &TrueFunc<>);
229}
230
231bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
232 const FakeQuantizationDescriptor& descriptor,
233 std::string* reasonIfUnsupported)
234{
235 ignore_unused(descriptor);
236 return IsSupportedForDataTypeRef(reasonIfUnsupported,
237 input.GetDataType(),
238 &TrueFunc<>,
239 &FalseFuncU8<>);
240}
241
242bool IsReshapeSupportedRef(const TensorInfo& input,
243 std::string* reasonIfUnsupported)
244{
245 return IsSupportedForDataTypeRef(reasonIfUnsupported,
246 input.GetDataType(),
247 &TrueFunc<>,
248 &TrueFunc<>);
249}
250
251bool IsFloorSupportedRef(const TensorInfo& input,
252 const TensorInfo& output,
253 std::string* reasonIfUnsupported)
254{
255 ignore_unused(output);
256 return IsSupportedForDataTypeRef(reasonIfUnsupported,
257 input.GetDataType(),
258 &TrueFunc<>,
259 &FalseFuncU8<>);
260}
261
262}