//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "LayerSupportCommon.hpp"

#include "ClLayerSupport.hpp"
#include "InternalTypes.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>

#include <boost/core/ignore_unused.hpp>

#ifdef ARMCOMPUTECL_ENABLED
#include "ClWorkloads/ClAdditionFloat32Workload.hpp"
#include "ClWorkloads/ClActivationFloat32Workload.hpp"
#include "ClWorkloads/ClBatchNormalizationFloat32Workload.hpp"

#include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
#include "ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
#include "ClWorkloads/ClConvolution2dBaseWorkload.hpp"
#include "ClWorkloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "ClWorkloads/ClL2NormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClMultiplicationFloat32Workload.hpp"
#include "ClWorkloads/ClFullyConnectedFloat32Workload.hpp"
#include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
#include "ClWorkloads/ClPermuteWorkload.hpp"
#include "ClWorkloads/ClNormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
#include "ClWorkloads/ClLstmFloat32Workload.hpp"
#endif

using namespace boost;

namespace armnn
{
namespace
{
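// Returns true when the weight tensor's width (dimension 3) and height (dimension 2) both equal FilterSize.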
template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

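// Returns true when actualStride matches any of the strides given as template arguments.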
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}

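// Returns true when Arm NN has been built with CL support; otherwise fills in reasonIfUnsupported and returns false.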
bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

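// Evaluates the wrapped support expression when CL support is compiled in; otherwise reports that the CL backend
// is unavailable via IsClBackendSupported.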
#if ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

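// When CL support is compiled in, FORWARD_WORKLOAD_VALIDATE_FUNC runs the given workload Validate() function and
// converts the resulting arm_compute::Status into a bool, copying the error description into reasonIfUnsupported.
// Otherwise it simply reports that the CL backend is unavailable.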
#if ARMCOMPUTECL_ENABLED
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

} // anonymous namespace

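// Dispatches a support decision on dataType: floatFuncPtr is used for both Float16 and Float32,
// uint8FuncPtr for QuantisedAsymm8. The CL backend itself must also be available.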
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      std::forward<Params>(params)...);
}

bool IsActivationSupportedCl(const TensorInfo& input,
                             const TensorInfo& output,
                             const ActivationDescriptor& descriptor,
                             std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool IsAdditionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
                                                            input1,
                                                            output,
                                                            reasonIfUnsupported));
}

bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const TensorInfo& mean,
                                     const TensorInfo& var,
                                     const TensorInfo& beta,
                                     const TensorInfo& gamma,
                                     const BatchNormalizationDescriptor& descriptor,
                                     std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool IsConstantSupportedCl(const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

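// Heuristic for the CL direct convolution path: 1x1 kernels with stride 1, 2 or 3, and 3x3 or 5x5 kernels with
// stride 1 or 2 are accepted, provided the padding is symmetric.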
bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
{
    bool isSupported = false;

    bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
    bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);

    bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
    bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);

    bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
    bool strideIsOneOrTwoOrThree = (strideXIsOneOrTwo || strideXIsThree) && (strideYIsOneOrTwo || strideYIsThree);

    // 1x1 convolution with strides of 1,2,3.
    isSupported |= IsMatchingSize2d<1>(weightInfo) && strideIsOneOrTwoOrThree;

    // 3x3 convolution with strides of 1,2.
    isSupported |= IsMatchingSize2d<3>(weightInfo) && strideIsOneOrTwo;

    // 5x5 convolution with strides of 1,2.
    isSupported |= IsMatchingSize2d<5>(weightInfo) && strideIsOneOrTwo;

    // Fall back to normal convolution for the asymmetric padding case.
    if (desc.m_PadLeft != desc.m_PadRight ||
        desc.m_PadTop != desc.m_PadBottom)
    {
        // Direct convolution does not support asymmetric padding yet.
        isSupported = false;
    }

    return isSupported;
}

bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
                                            const Convolution2dDescriptor& parameters,
                                            const TensorInfo& weightInfo)
{
    return IsClDirectConvolution2dSupported(weightInfo, parameters);
}

bool IsConvolution2dSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const Convolution2dDescriptor& descriptor,
                                const TensorInfo& weights,
                                const boost::optional<TensorInfo>& biases,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const DepthwiseConvolution2dDescriptor& descriptor,
                                       const TensorInfo& weights,
                                       const boost::optional<TensorInfo>& biases,
                                       std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool IsFullyConnectedSupportedCl(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const TensorInfo& weights,
                                 const TensorInfo& biases,
                                 const FullyConnectedDescriptor& descriptor,
                                 std::string* reasonIfUnsupported)
{
    // At the moment U8 is unsupported.
    if (input.GetDataType() == DataType::QuantisedAsymm8)
    {
        return false;
    }
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool IsInputSupportedCl(const TensorInfo& input,
                        std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsL2NormalizationSupportedCl(const TensorInfo& input,
                                  const TensorInfo& output,
                                  std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output);
}

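// Merger is decided on the data type of the first input only; QuantisedAsymm8 inputs are not supported.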
bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
                         const OriginsDescriptor& descriptor,
                         std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    inputs[0]->GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsMultiplicationSupportedCl(const TensorInfo& input0,
                                 const TensorInfo& input1,
                                 const TensorInfo& output,
                                 std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool IsNormalizationSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const NormalizationDescriptor& descriptor,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsOutputSupportedCl(const TensorInfo& output,
                         std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsPermuteSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const PermuteDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool IsPooling2dSupportedCl(const TensorInfo& input,
                            const TensorInfo& output,
                            const Pooling2dDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsResizeBilinearSupportedCl(const TensorInfo& input,
                                 std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsSoftmaxSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const SoftmaxDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool IsSplitterSupportedCl(const TensorInfo& input,
                           const ViewsDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
                                   const FakeQuantizationDescriptor& descriptor,
                                   std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(descriptor);
    return false;
}

bool IsReshapeSupportedCl(const TensorInfo& input,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    return true;
}

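// Floor is only supported for Float32 on CL; Float16 and QuantisedAsymm8 inputs are rejected.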
bool IsFloorSupportedCl(const TensorInfo& input,
                        const TensorInfo& output,
                        std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &FalseFuncF16<>,
                                         &TrueFunc<>,
                                         &FalseFuncU8<>);
}

bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
                       const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                       const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                       const TensorInfo& output, const LstmDescriptor& descriptor,
                       const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
                       const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
                       const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
                       const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
                       const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
                       const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
                       const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
                       const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
                       const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloat32WorkloadValidate, reasonIfUnsupported,
                                   input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
                                   output, descriptor, inputToForgetWeights, inputToCellWeights,
                                   inputToOutputWeights, recurrentToForgetWeights,
                                   recurrentToCellWeights, recurrentToOutputWeights,
                                   forgetGateBias, cellBias, outputGateBias,
                                   inputToInputWeights, recurrentToInputWeights,
                                   cellToInputWeights, inputGateBias, projectionWeights,
                                   projectionBias, cellToForgetWeights, cellToOutputWeights);
}

bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

} // namespace armnn