//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "LayerSupportCommon.hpp"

#include "ClLayerSupport.hpp"
#include "InternalTypes.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>

#include <boost/core/ignore_unused.hpp>

#ifdef ARMCOMPUTECL_ENABLED
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationFloatWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClConvertFp16ToFp32Workload.hpp"
#include "workloads/ClConvertFp32ToFp16Workload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPooling2dBaseWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#endif

using namespace boost;

namespace armnn
{
namespace
{
template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

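// Checks actualStride against a compile-time list of valid strides; the variadic overload
// below folds the individual checks together with logical OR.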
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) ||
           IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}

bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

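// When the CL backend is compiled in, FORWARD_CL_LAYER_SUPPORT_FUNC simply evaluates the
// wrapped validation expression; otherwise it reports the backend as unavailable via
// IsClBackendSupported, which also fills in reasonIfUnsupported.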
#if ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

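// Translates the arm_compute::Status returned by a workload's Validate() function into the
// bool plus optional reason string used by the layer-support API. The
// FORWARD_WORKLOAD_VALIDATE_FUNC macro expands to a return statement that forwards the query
// to such a Validate() function when the CL backend is available, and falls back to
// IsClBackendSupported otherwise.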
#if ARMCOMPUTECL_ENABLED
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

} // anonymous namespace

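// Helper that first checks the CL backend is compiled in, then dispatches on the data type:
// both float variants are routed to floatFuncPtr, the 8-bit quantised type to uint8FuncPtr.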
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      std::forward<Params>(params)...);
}

bool IsActivationSupportedCl(const TensorInfo& input,
                             const TensorInfo& output,
                             const ActivationDescriptor& descriptor,
                             std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool IsAdditionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
                                                            input1,
                                                            output,
                                                            reasonIfUnsupported));
}

bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const TensorInfo& mean,
                                     const TensorInfo& var,
                                     const TensorInfo& beta,
                                     const TensorInfo& gamma,
                                     const BatchNormalizationDescriptor& descriptor,
                                     std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool IsConstantSupportedCl(const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

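// The CL direct convolution path only covers a small set of filter/stride combinations
// (1x1 with stride 1, 2 or 3; 3x3 and 5x5 with stride 1 or 2) and symmetric padding;
// anything else falls back to the general convolution workload.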
bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
{
    bool isSupported = false;

    bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
    bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);

    bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
    bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);

    bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
    bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );

    // 1x1 convolution with strides of 1,2,3.
    isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );

    // 3x3 convolution with strides of 1,2.
    isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );

    // 5x5 convolution with strides of 1,2.
    isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );

    // Fall back to normal convolution for the asymmetric padding case.
    if (desc.m_PadLeft != desc.m_PadRight ||
        desc.m_PadTop != desc.m_PadBottom)
    {
        // Direct convolution does not support asymmetric padding yet.
        isSupported = false;
    }

    return isSupported;
}

bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
                                            const Convolution2dDescriptor& parameters,
                                            const TensorInfo& weightInfo)
{
    return IsClDirectConvolution2dSupported(weightInfo, parameters);
}

bool IsConvolution2dSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const Convolution2dDescriptor& descriptor,
                                const TensorInfo& weights,
                                const boost::optional<TensorInfo>& biases,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}
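
// Illustrative sketch only (hypothetical shapes and names, not part of this file): a caller
// can query CL support for a convolution before creating the workload, e.g.
//
//     armnn::TensorInfo input({1, 3, 224, 224}, armnn::DataType::Float32);
//     armnn::TensorInfo output({1, 64, 112, 112}, armnn::DataType::Float32);
//     armnn::TensorInfo weights({64, 3, 7, 7}, armnn::DataType::Float32);
//     armnn::Convolution2dDescriptor descriptor;
//     descriptor.m_StrideX = 2;
//     descriptor.m_StrideY = 2;
//     std::string reason;
//     bool supported = armnn::IsConvolution2dSupportedCl(input, output, descriptor, weights,
//                                                        boost::optional<armnn::TensorInfo>(), &reason);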

bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const DepthwiseConvolution2dDescriptor& descriptor,
                                       const TensorInfo& weights,
                                       const boost::optional<TensorInfo>& biases,
                                       std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool IsDivisionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool IsSubtractionSupportedCl(const TensorInfo& input0,
                              const TensorInfo& input1,
                              const TensorInfo& output,
                              std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClSubtractionValidate(input0,
                                                               input1,
                                                               output,
                                                               reasonIfUnsupported));
}

bool IsFullyConnectedSupportedCl(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const TensorInfo& weights,
                                 const TensorInfo& biases,
                                 const FullyConnectedDescriptor& descriptor,
                                 std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool IsInputSupportedCl(const TensorInfo& input,
                        std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsL2NormalizationSupportedCl(const TensorInfo& input,
                                  const TensorInfo& output,
                                  const L2NormalizationDescriptor& descriptor,
                                  std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
                         const OriginsDescriptor& descriptor,
                         std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    inputs[0]->GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsMultiplicationSupportedCl(const TensorInfo& input0,
                                 const TensorInfo& input1,
                                 const TensorInfo& output,
                                 std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool IsNormalizationSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const NormalizationDescriptor& descriptor,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsOutputSupportedCl(const TensorInfo& output,
                         std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsPadSupportedCl(const TensorInfo& input,
                      const TensorInfo& output,
                      const PadDescriptor& descriptor,
                      std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClPadValidate(input, output, descriptor, reasonIfUnsupported));
}

bool IsPermuteSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const PermuteDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool IsPooling2dSupportedCl(const TensorInfo& input,
                            const TensorInfo& output,
                            const Pooling2dDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsResizeBilinearSupportedCl(const TensorInfo& input,
                                 std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsSoftmaxSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const SoftmaxDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool IsSplitterSupportedCl(const TensorInfo& input,
                           const ViewsDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
                                   const FakeQuantizationDescriptor& descriptor,
                                   std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(descriptor);
    return false;
}

bool IsReshapeSupportedCl(const TensorInfo& input,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    return true;
}

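// Floor on the CL backend is only reported as supported for Float32 inputs; Float16 and
// 8-bit quantised inputs are rejected by the type dispatch below.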
bool IsFloorSupportedCl(const TensorInfo& input,
                        const TensorInfo& output,
                        std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      input.GetDataType(),
                                      &FalseFuncF16<>,
                                      &TrueFunc<>,
                                      &FalseFuncU8<>);
}

bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
                       const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                       const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                       const TensorInfo& output, const LstmDescriptor& descriptor,
                       const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
                       const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
                       const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
                       const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
                       const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
                       const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
                       const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
                       const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
                       const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported,
                                   input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
                                   output, descriptor, inputToForgetWeights, inputToCellWeights,
                                   inputToOutputWeights, recurrentToForgetWeights,
                                   recurrentToCellWeights, recurrentToOutputWeights,
                                   forgetGateBias, cellBias, outputGateBias,
                                   inputToInputWeights, recurrentToInputWeights,
                                   cellToInputWeights, inputGateBias, projectionWeights,
                                   projectionBias, cellToForgetWeights, cellToOutputWeights);
}

bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

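// Mean is not yet implemented for the CL backend, so it is always reported as unsupported.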
bool IsMeanSupportedCl(const TensorInfo& input,
                       const TensorInfo& output,
                       const MeanDescriptor& descriptor,
                       std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    ignore_unused(descriptor);
    return false;
}

} // namespace armnn