//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

6#include "LayerSupportCommon.hpp"
7
8#include "ClLayerSupport.hpp"
9#include "InternalTypes.hpp"
10
11#include <armnn/Descriptors.hpp>
12#include <armnn/Types.hpp>
13#include <armnn/Tensor.hpp>
14
15#include <boost/core/ignore_unused.hpp>
16
17#ifdef ARMCOMPUTECL_ENABLED
18#include "ClWorkloads/ClAdditionFloat32Workload.hpp"
19#include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
20#include "ClWorkloads/ClPermuteWorkload.hpp"
21#include "ClWorkloads/ClNormalizationFloat32Workload.hpp"
22#endif
23
24using namespace boost;
25
26namespace armnn
27{
28namespace
29{
30template<unsigned int FilterSize>
31bool IsMatchingSize2d(const TensorInfo& weightInfo)
32{
33 // Width & Height must match
34 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
35}
36
37template<uint32_t ValidStride>
38bool IsMatchingStride(uint32_t actualStride)
39{
40 return ValidStride == actualStride;
41}
42
43template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
44bool IsMatchingStride(uint32_t actualStride)
45{
46 return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
47};
48
// Reports whether the CL backend is compiled into this build. When CL support
// is absent, optionally writes an explanation into reasonIfUnsupported.
// NOTE(review): this uses '#if ARMCOMPUTECL_ENABLED' while the include block
// uses '#ifdef' — fine as long as the macro is always defined as 1, but worth
// confirming the build system never defines it empty.
bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}
61
// With CL compiled in, evaluate the wrapped support expression directly;
// otherwise short-circuit to IsClBackendSupported, which fills in the reason.
// NOTE(review): the fallback expansion relies on a variable named
// 'reasonIfUnsupported' being in scope at every expansion site.
#if ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif
67
#if ARMCOMPUTECL_ENABLED
// Runs an Arm Compute Library validate function and converts its
// arm_compute::Status into a bool, copying the ACL error description into
// reasonIfUnsupported on failure.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

// Expands to a 'return' statement — must be the final statement of the
// calling function.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// Without CL support every workload query degrades to the backend check.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
87
88} //namespace
89
90bool IsClActivationUint8Supported(std::string* reasonIfUnsupported, const ActivationDescriptor& parameters)
91{
92 if (parameters.m_Function != ActivationFunction::BoundedReLu)
93 {
94 if (reasonIfUnsupported)
95 {
96 *reasonIfUnsupported = "Unsupported activation function, only BoundedReLu is supported";
97 }
98
99 return false;
100 }
101
102 return true;
103}
104
105bool IsClDepthwiseConvolution2dDescParamsSupported(std::string* reasonIfUnsupported,
106 const DepthwiseConvolution2dDescriptor& parameters,
107 const TensorInfo& weights)
108{
109 if (weights.GetNumDimensions() != 4)
110 {
111 if (reasonIfUnsupported)
112 {
113 *reasonIfUnsupported = "Depwthwise convolution Weight tensor needs to be 4d";
114 }
115 return false;
116 }
117 // weights.GetShape()[0] = channel multiplier
118 if (weights.GetShape()[0] != 1)
119 {
120 if (reasonIfUnsupported)
121 {
122 *reasonIfUnsupported = "Channel multiplier only supports the value 1 in the CL backend";
123 }
124 return false;
125 }
126 else if ((weights.GetDataType() == armnn::DataType::QuantisedAsymm8) && !IsMatchingSize2d<3>(weights))
127 {
128 if (reasonIfUnsupported)
129 {
130 *reasonIfUnsupported = "CL backend only supports 3x3 filtering for Depthwise Convolution on 8-bit";
131 }
132 return false;
133 }
134
135 return true;
136}
137
// Common dispatch helper: first checks that the CL backend is available, then
// routes to the float32 or uint8 predicate according to dataType, forwarding
// any extra parameters to the chosen predicate.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
                              DataType dataType,
                              Float32Func floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      std::forward<Params>(params)...);
}
152
153bool IsActivationSupportedCl(const TensorInfo& input,
154 const ActivationDescriptor& descriptor,
155 std::string* reasonIfUnsupported)
156{
157 return IsSupportedForDataTypeCl(reasonIfUnsupported,
158 input.GetDataType(),
159 &TrueFunc<const ActivationDescriptor&>,
160 &IsClActivationUint8Supported,
161 descriptor);
162}
163
164bool IsAdditionSupportedCl(const TensorInfo& input0,
165 const TensorInfo& input1,
166 const TensorInfo& output,
167 std::string* reasonIfUnsupported)
168{
169 return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionFloat32Workload::IsSupported(input0,
170 input1,
171 output,
172 reasonIfUnsupported));
173}
174
175bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
176 const BatchNormalizationDescriptor& descriptor,
177 std::string* reasonIfUnsupported)
178{
179 return IsSupportedForDataTypeCl(reasonIfUnsupported,
180 input.GetDataType(),
181 &TrueFunc<const BatchNormalizationDescriptor&>,
182 &FalseFuncU8<const BatchNormalizationDescriptor&>,
183 descriptor);
184}
185
186bool IsConstantSupportedCl(const TensorInfo& output,
187 std::string* reasonIfUnsupported)
188{
189 return IsSupportedForDataTypeCl(reasonIfUnsupported,
190 output.GetDataType(),
191 &TrueFunc<>,
192 &FalseFuncU8<>);
193}
194
195bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
196{
197 bool isSupported = false;
198
199 bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
200 bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);
201
202 bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
203 bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);
204
205 bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
206 bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
207
208 // 1x1 convolution with strides of 1,2,3
209 isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
210
211 // 3x3 convolution with strides of 1,2
212 isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
213
214 // 5x5 convolution with strides of 1,2
215 isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
216
217 //fall back to normal convolution for the asymmetric padding case.
218 if (desc.m_PadLeft != desc.m_PadRight ||
219 desc.m_PadTop != desc.m_PadBottom)
220 {
221 //direct convolution does not support asymmetric padding yet.
222 isSupported = false;
223 }
224
225 return isSupported;
226}
227
228bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
229 const Convolution2dDescriptor& parameters,
230 const TensorInfo& weightInfo)
231{
232 return IsClDirectConvolution2dSupported(weightInfo, parameters);
233}
234
235bool IsConvolution2dSupportedCl(const TensorInfo& input,
236 const Convolution2dDescriptor& descriptor,
237 const TensorInfo& weights,
238 std::string* reasonIfUnsupported)
239{
240 return IsSupportedForDataTypeCl(reasonIfUnsupported,
241 input.GetDataType(),
242 &TrueFunc<decltype(descriptor), decltype(weights)>,
243 &IsDirectConvolution2dParamsSupportedCl,
244 descriptor,
245 weights);
246}
247
248bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
249 const DepthwiseConvolution2dDescriptor& descriptor,
250 const TensorInfo& weights,
251 std::string* reasonIfUnsupported)
252{
253 return IsSupportedForDataTypeCl(reasonIfUnsupported,
254 input.GetDataType(),
255 &IsClDepthwiseConvolution2dDescParamsSupported,
256 &IsClDepthwiseConvolution2dDescParamsSupported,
257 descriptor,
258 weights);
259}
260
261bool IsFullyConnectedSupportedCl(const TensorInfo& input,
262 const FullyConnectedDescriptor& descriptor,
263 std::string* reasonIfUnsupported)
264{
265 ignore_unused(descriptor);
266 return IsSupportedForDataTypeCl(reasonIfUnsupported,
267 input.GetDataType(),
268 &TrueFunc<>,
269 &FalseFuncU8<>);
270}
271
272bool IsInputSupportedCl(const TensorInfo& input,
273 std::string* reasonIfUnsupported)
274{
275 return IsSupportedForDataTypeCl(reasonIfUnsupported,
276 input.GetDataType(),
277 &TrueFunc<>,
278 &TrueFunc<>);
279}
280
281bool IsL2NormalizationSupportedCl(const TensorInfo& input,
282 std::string* reasonIfUnsupported)
283{
284 return IsSupportedForDataTypeCl(reasonIfUnsupported,
285 input.GetDataType(),
286 &TrueFunc<>,
287 &FalseFuncU8<>);
288}
289
290bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
291 const OriginsDescriptor& descriptor,
292 std::string* reasonIfUnsupported)
293{
294 ignore_unused(descriptor);
295 return IsSupportedForDataTypeCl(reasonIfUnsupported,
296 inputs[0]->GetDataType(),
297 &TrueFunc<>,
298 &FalseFuncU8<>);
299}
300
301bool IsMultiplicationSupportedCl(const TensorInfo& input0,
302 const TensorInfo& input1,
303 std::string* reasonIfUnsupported)
304{
305 ignore_unused(input1);
306 return IsSupportedForDataTypeCl(reasonIfUnsupported,
307 input0.GetDataType(),
308 &TrueFunc<>,
309 &FalseFuncU8<>);
310}
311
// Normalization support is delegated to the ACL validate function.
// Note: FORWARD_WORKLOAD_VALIDATE_FUNC expands to a 'return' statement.
bool IsNormalizationSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const NormalizationDescriptor& descriptor,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
319
320bool IsOutputSupportedCl(const TensorInfo& output,
321 std::string* reasonIfUnsupported)
322{
323 return IsSupportedForDataTypeCl(reasonIfUnsupported,
324 output.GetDataType(),
325 &TrueFunc<>,
326 &TrueFunc<>);
327}
328
// Permute support depends only on the permutation descriptor; the tensor
// infos are not inspected. FORWARD_WORKLOAD_VALIDATE_FUNC expands to a
// 'return' statement.
bool IsPermuteSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const PermuteDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}
338
// Pooling support is delegated to the ACL validate function.
// Note: FORWARD_WORKLOAD_VALIDATE_FUNC expands to a 'return' statement.
bool IsPooling2dSupportedCl(const TensorInfo& input,
                            const TensorInfo& output,
                            const Pooling2dDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
346
347bool IsResizeBilinearSupportedCl(const TensorInfo& input,
348 std::string* reasonIfUnsupported)
349{
350 return IsSupportedForDataTypeCl(reasonIfUnsupported,
351 input.GetDataType(),
352 &TrueFunc<>,
353 &FalseFuncU8<>);
354}
355
356bool IsSoftmaxSupportedCl(const TensorInfo& input,
357 const SoftmaxDescriptor& descriptor,
358 std::string* reasonIfUnsupported)
359{
360 ignore_unused(descriptor);
361 return IsSupportedForDataTypeCl(reasonIfUnsupported,
362 input.GetDataType(),
363 &TrueFunc<>,
364 &TrueFunc<>);
365}
366
367bool IsSplitterSupportedCl(const TensorInfo& input,
368 const ViewsDescriptor& descriptor,
369 std::string* reasonIfUnsupported)
370{
371 ignore_unused(descriptor);
372 return IsSupportedForDataTypeCl(reasonIfUnsupported,
373 input.GetDataType(),
374 &TrueFunc<>,
375 &TrueFunc<>);
376}
377
378bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
379 const FakeQuantizationDescriptor& descriptor,
380 std::string* reasonIfUnsupported)
381{
382 ignore_unused(input);
383 ignore_unused(descriptor);
384 return false;
385}
386
// Reshape is always supported: it only rearranges metadata, so no property
// of the input tensor can make it unsupported.
bool IsReshapeSupportedCl(const TensorInfo& input,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    return true;
}
393
394bool IsFloorSupportedCl(const TensorInfo& input,
395 const TensorInfo& output,
396 std::string* reasonIfUnsupported)
397{
398 ignore_unused(output);
399 return IsSupportedForDataTypeCl(reasonIfUnsupported,
400 input.GetDataType(),
401 &TrueFunc<>,
402 &FalseFuncU8<>);
403}
404
405}