//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "LayerSupportCommon.hpp"

#include "ClLayerSupport.hpp"
#include "InternalTypes.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>

#include <boost/core/ignore_unused.hpp>

#ifdef ARMCOMPUTECL_ENABLED
#include "ClWorkloads/ClAdditionFloatWorkload.hpp"
#include "ClWorkloads/ClActivationFloatWorkload.hpp"
#include "ClWorkloads/ClBatchNormalizationFloatWorkload.hpp"
#include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
#include "ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
#include "ClWorkloads/ClConvolution2dBaseWorkload.hpp"
#include "ClWorkloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "ClWorkloads/ClDivisionFloatWorkload.hpp"
#include "ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
#include "ClWorkloads/ClMultiplicationFloatWorkload.hpp"
#include "ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
#include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
#include "ClWorkloads/ClPermuteWorkload.hpp"
#include "ClWorkloads/ClNormalizationFloatWorkload.hpp"
#include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
#include "ClWorkloads/ClLstmFloatWorkload.hpp"
#endif

using namespace boost;

namespace armnn
{
namespace
{
template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

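// Returns true when actualStride matches any of the ValidStrides template arguments.
// The single-parameter overload is the base case; the variadic overload below recurses
// through the remaining candidate strides.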
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) ||
           IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}

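// Reports whether this build of Arm NN was compiled with CL support; if not, writes an
// explanation to reasonIfUnsupported when a string is provided.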
bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

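// Evaluates the wrapped support expression when CL support is compiled in; otherwise
// falls back to IsClBackendSupported to report why CL is unavailable.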
#if ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

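// Converts the arm_compute::Status returned by a workload's Validate() call into a bool,
// copying the ACL error description into reasonIfUnsupported on failure.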
#if ARMCOMPUTECL_ENABLED
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

} //namespace

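// Checks CL availability, then dispatches to the float or uint8 predicate for the given
// data type; Float16 and Float32 share the same float predicate.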
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         floatFuncPtr,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         std::forward<Params>(params)...);
}

bool IsActivationSupportedCl(const TensorInfo& input,
                             const TensorInfo& output,
                             const ActivationDescriptor& descriptor,
                             std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool IsAdditionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
                                                            input1,
                                                            output,
                                                            reasonIfUnsupported));
}

bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const TensorInfo& mean,
                                     const TensorInfo& var,
                                     const TensorInfo& beta,
                                     const TensorInfo& gamma,
                                     const BatchNormalizationDescriptor& descriptor,
                                     std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool IsConstantSupportedCl(const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

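// Direct (non-GEMM) convolution on CL is only used for 1x1 kernels with stride 1, 2 or 3
// and 3x3/5x5 kernels with stride 1 or 2, and only with symmetric padding.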
bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
{
    bool isSupported = false;

    bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
    bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);

    bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
    bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);

    bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
    bool strideIsOneOrTwoOrThree = (strideXIsOneOrTwo || strideXIsThree) && (strideYIsOneOrTwo || strideYIsThree);

    // 1x1 convolution with strides of 1, 2 or 3.
    isSupported |= IsMatchingSize2d<1>(weightInfo) && strideIsOneOrTwoOrThree;

    // 3x3 convolution with strides of 1 or 2.
    isSupported |= IsMatchingSize2d<3>(weightInfo) && strideIsOneOrTwo;

    // 5x5 convolution with strides of 1 or 2.
    isSupported |= IsMatchingSize2d<5>(weightInfo) && strideIsOneOrTwo;

    // Fall back to normal convolution for the asymmetric padding case.
    if (desc.m_PadLeft != desc.m_PadRight ||
        desc.m_PadTop != desc.m_PadBottom)
    {
        // Direct convolution does not support asymmetric padding yet.
        isSupported = false;
    }

    return isSupported;
}

bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
                                            const Convolution2dDescriptor& parameters,
                                            const TensorInfo& weightInfo)
{
    return IsClDirectConvolution2dSupported(weightInfo, parameters);
}

bool IsConvolution2dSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const Convolution2dDescriptor& descriptor,
                                const TensorInfo& weights,
                                const boost::optional<TensorInfo>& biases,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const DepthwiseConvolution2dDescriptor& descriptor,
                                       const TensorInfo& weights,
                                       const boost::optional<TensorInfo>& biases,
                                       std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}
240
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100241bool IsDivisionSupportedCl(const TensorInfo& input0,
242 const TensorInfo& input1,
243 const TensorInfo& output,
244 std::string* reasonIfUnsupported)
245{
246 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
247 reasonIfUnsupported,
248 input0,
249 input1,
250 output);
251}
252
telsoa014fcda012018-03-09 14:13:49 +0000253bool IsFullyConnectedSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100254 const TensorInfo& output,
255 const TensorInfo& weights,
256 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000257 const FullyConnectedDescriptor& descriptor,
258 std::string* reasonIfUnsupported)
259{
telsoa01c577f2c2018-08-31 09:22:23 +0100260 // At the moment U8 is unsupported
261 if (input.GetDataType() == DataType::QuantisedAsymm8)
262 {
263 return false;
264 }
265 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
266 reasonIfUnsupported,
267 input,
268 output,
269 weights,
270 biases,
271 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000272}
273
274bool IsInputSupportedCl(const TensorInfo& input,
275 std::string* reasonIfUnsupported)
276{
277 return IsSupportedForDataTypeCl(reasonIfUnsupported,
278 input.GetDataType(),
279 &TrueFunc<>,
280 &TrueFunc<>);
281}
282
283bool IsL2NormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100284 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000285 std::string* reasonIfUnsupported)
286{
telsoa01c577f2c2018-08-31 09:22:23 +0100287 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output);
telsoa014fcda012018-03-09 14:13:49 +0000288}
289
290bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
291 const OriginsDescriptor& descriptor,
292 std::string* reasonIfUnsupported)
293{
294 ignore_unused(descriptor);
295 return IsSupportedForDataTypeCl(reasonIfUnsupported,
296 inputs[0]->GetDataType(),
297 &TrueFunc<>,
298 &FalseFuncU8<>);
299}
300
301bool IsMultiplicationSupportedCl(const TensorInfo& input0,
302 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100303 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000304 std::string* reasonIfUnsupported)
305{
telsoa01c577f2c2018-08-31 09:22:23 +0100306 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
307 reasonIfUnsupported,
308 input0,
309 input1,
310 output);
telsoa014fcda012018-03-09 14:13:49 +0000311}
312
313bool IsNormalizationSupportedCl(const TensorInfo& input,
314 const TensorInfo& output,
315 const NormalizationDescriptor& descriptor,
316 std::string* reasonIfUnsupported)
317{
318 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
319}
320
321bool IsOutputSupportedCl(const TensorInfo& output,
322 std::string* reasonIfUnsupported)
323{
324 return IsSupportedForDataTypeCl(reasonIfUnsupported,
325 output.GetDataType(),
326 &TrueFunc<>,
327 &TrueFunc<>);
328}
329
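// The permute workload validates only the permutation descriptor; the input and output
// shapes are not inspected here.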
bool IsPermuteSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const PermuteDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool IsPooling2dSupportedCl(const TensorInfo& input,
                            const TensorInfo& output,
                            const Pooling2dDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsResizeBilinearSupportedCl(const TensorInfo& input,
                                 std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsSoftmaxSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const SoftmaxDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool IsSplitterSupportedCl(const TensorInfo& input,
                           const ViewsDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
                                   const FakeQuantizationDescriptor& descriptor,
                                   std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(descriptor);
    return false;
}

bool IsReshapeSupportedCl(const TensorInfo& input,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    return true;
}

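// Floor is only supported for Float32 on CL; Float16 and QuantisedAsymm8 inputs are
// rejected.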
bool IsFloorSupportedCl(const TensorInfo& input,
                        const TensorInfo& output,
                        std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &FalseFuncF16<>,
                                         &TrueFunc<>,
                                         &FalseFuncU8<>);
}

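// LSTM support is decided entirely by the float workload's Validate function; the
// optional tensors (CIFG, peephole and projection parameters) are passed as pointers.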
bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
                       const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                       const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                       const TensorInfo& output, const LstmDescriptor& descriptor,
                       const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
                       const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
                       const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
                       const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
                       const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
                       const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
                       const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
                       const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
                       const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported,
                                   input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
                                   output, descriptor, inputToForgetWeights, inputToCellWeights,
                                   inputToOutputWeights, recurrentToForgetWeights,
                                   recurrentToCellWeights, recurrentToOutputWeights,
                                   forgetGateBias, cellBias, outputGateBias,
                                   inputToInputWeights, recurrentToInputWeights,
                                   cellToInputWeights, inputGateBias, projectionWeights,
                                   projectionBias, cellToForgetWeights, cellToOutputWeights);
}

bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

} // namespace armnn