blob: 6a49a80c7fb4a0a26042c1285f4d521c7dc0ee99 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "ClLayerSupport.hpp"
arovir017c22c702018-10-09 11:16:46 +01007
David Beck3cc9a622018-10-12 10:38:31 +01008#include <armnn/Descriptors.hpp>
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +01009#include <armnn/InternalTypes.hpp>
10#include <armnn/LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
12#include <boost/core/ignore_unused.hpp>
13
14#ifdef ARMCOMPUTECL_ENABLED
David Beckac42efd2018-09-26 17:41:13 +010015#include "workloads/ClAdditionWorkload.hpp"
Nattapat Chaimanowonge06757e2018-10-11 15:39:18 +010016#include "workloads/ClActivationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010017#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
18#include "workloads/ClConvertFp16ToFp32Workload.hpp"
19#include "workloads/ClConvertFp32ToFp16Workload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010020#include "workloads/ClConvolution2dWorkload.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +010021#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010022#include "workloads/ClDivisionFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010023#include "workloads/ClFullyConnectedWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010024#include "workloads/ClL2NormalizationFloatWorkload.hpp"
25#include "workloads/ClLstmFloatWorkload.hpp"
26#include "workloads/ClMultiplicationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010027#include "workloads/ClNormalizationFloatWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010028#include "workloads/ClPadWorkload.hpp"
29#include "workloads/ClPermuteWorkload.hpp"
Nattapat Chaimanowongac9e0962018-10-10 17:18:35 +010030#include "workloads/ClPooling2dWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010031#include "workloads/ClSoftmaxBaseWorkload.hpp"
32#include "workloads/ClSubtractionWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +000033#endif
34
35using namespace boost;
36
37namespace armnn
38{
arovir017c22c702018-10-09 11:16:46 +010039
telsoa014fcda012018-03-09 14:13:49 +000040namespace
41{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +010042
telsoa014fcda012018-03-09 14:13:49 +000043template<unsigned int FilterSize>
44bool IsMatchingSize2d(const TensorInfo& weightInfo)
45{
telsoa01c577f2c2018-08-31 09:22:23 +010046 // Width & Height must match.
telsoa014fcda012018-03-09 14:13:49 +000047 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
48}
49
// Single-candidate case: the stride matches only if it equals ValidStride.
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return actualStride == ValidStride;
}

// Multi-candidate case: the stride matches if it equals any of the listed strides.
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    const uint32_t acceptedStrides[] = { FirstStride, SecondStride, ValidStrides... };
    for (const uint32_t accepted : acceptedStrides)
    {
        if (accepted == actualStride)
        {
            return true;
        }
    }
    return false;
}
telsoa014fcda012018-03-09 14:13:49 +000061
arovir01085f0a42018-10-08 14:48:19 +010062bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000063{
64#if ARMCOMPUTECL_ENABLED
65 return true;
66#else
arovir01085f0a42018-10-08 14:48:19 +010067 if (reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000068 {
arovir01085f0a42018-10-08 14:48:19 +010069 reasonIfUnsupported.value() = "The armnn library has been built without CL support";
telsoa014fcda012018-03-09 14:13:49 +000070 }
71 return false;
72#endif
73}
74
// FORWARD_CL_LAYER_SUPPORT_FUNC(expr):
//  - with CL support compiled in, it expands to `(expr)`;
//  - without it, `expr` is discarded and the expansion is IsClBackendSupported(reasonIfUnsupported).
//    That branch requires a variable named `reasonIfUnsupported` in scope at the expansion site.
#if !ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#endif
80
#if ARMCOMPUTECL_ENABLED
// Calls an Arm Compute Library validate function with the forwarded arguments.
// Returns true iff the resulting status code is arm_compute::ErrorCode::OK. On failure, the
// status' error description is copied into reasonIfUnsupported when one was supplied.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status validationStatus = func(std::forward<Args>(args)...);
    if (validationStatus.error_code() == arm_compute::ErrorCode::OK)
    {
        return true;
    }
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = validationStatus.error_description();
    }
    return false;
}

// Expands to a complete `return` statement, so the invoking function returns the validation result.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// Without CL support, the validate function and its arguments are discarded.
// The invoking function returns the result of IsClBackendSupported instead.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
100
telsoa01c577f2c2018-08-31 09:22:23 +0100101template<typename FloatFunc, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100102bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000103 DataType dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100104 FloatFunc floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000105 Uint8Func uint8FuncPtr,
106 Params&&... params)
107{
108 return IsClBackendSupported(reasonIfUnsupported) &&
109 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
110 dataType,
111 floatFuncPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100112 floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000113 uint8FuncPtr,
114 std::forward<Params>(params)...);
115}
116
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100117} // anonymous namespace
118
119bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
120 const TensorInfo& output,
121 const ActivationDescriptor& descriptor,
122 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000123{
telsoa01c577f2c2018-08-31 09:22:23 +0100124 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
125 reasonIfUnsupported,
126 input,
127 output,
128 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000129}
130
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100131bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
132 const TensorInfo& input1,
133 const TensorInfo& output,
134 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000135{
arovir01085f0a42018-10-08 14:48:19 +0100136 FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
137 reasonIfUnsupported,
138 input0,
139 input1,
140 output);
telsoa014fcda012018-03-09 14:13:49 +0000141}
142
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100143bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
144 const TensorInfo& output,
145 const TensorInfo& mean,
146 const TensorInfo& var,
147 const TensorInfo& beta,
148 const TensorInfo& gamma,
149 const BatchNormalizationDescriptor& descriptor,
150 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000151{
telsoa01c577f2c2018-08-31 09:22:23 +0100152 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
153 reasonIfUnsupported,
154 input,
155 output,
156 mean,
157 var,
158 beta,
159 gamma,
160 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000161}
162
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100163bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
164 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000165{
166 return IsSupportedForDataTypeCl(reasonIfUnsupported,
167 output.GetDataType(),
168 &TrueFunc<>,
169 &FalseFuncU8<>);
170}
171
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100172bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
173 const TensorInfo& output,
174 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000175{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100176 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
177 reasonIfUnsupported,
178 input,
179 output);
telsoa014fcda012018-03-09 14:13:49 +0000180}
181
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100182bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
183 const TensorInfo& output,
184 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000185{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100186 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
187 reasonIfUnsupported,
188 input,
189 output);
telsoa014fcda012018-03-09 14:13:49 +0000190}
191
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100192bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
193 const TensorInfo& output,
194 const Convolution2dDescriptor& descriptor,
195 const TensorInfo& weights,
196 const Optional<TensorInfo>& biases,
197 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000198{
surmeh013537c2c2018-05-18 16:31:43 +0100199 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
200 reasonIfUnsupported,
201 input,
202 output,
203 descriptor,
204 weights,
205 biases);
telsoa014fcda012018-03-09 14:13:49 +0000206}
207
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100208bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
209 const TensorInfo& output,
210 const DepthwiseConvolution2dDescriptor& descriptor,
211 const TensorInfo& weights,
212 const Optional<TensorInfo>& biases,
213 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000214{
telsoa01c577f2c2018-08-31 09:22:23 +0100215 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
216 reasonIfUnsupported,
217 input,
218 output,
219 descriptor,
220 weights,
221 biases);
telsoa014fcda012018-03-09 14:13:49 +0000222}
223
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100224bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
225 const TensorInfo& input1,
226 const TensorInfo& output,
227 Optional<std::string&> reasonIfUnsupported) const
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100228{
229 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
230 reasonIfUnsupported,
231 input0,
232 input1,
233 output);
234}
235
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100236bool ClLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
237 const FakeQuantizationDescriptor& descriptor,
238 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000239{
240 ignore_unused(input);
241 ignore_unused(descriptor);
arovir01085f0a42018-10-08 14:48:19 +0100242 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000243 return false;
244}
245
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100246bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
247 const TensorInfo& output,
248 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000249{
250 ignore_unused(output);
telsoa01c577f2c2018-08-31 09:22:23 +0100251 return IsClBackendSupported(reasonIfUnsupported) &&
252 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
253 input.GetDataType(),
254 &FalseFuncF16<>,
255 &TrueFunc<>,
256 &FalseFuncU8<>);
257}
258
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100259bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
260 const TensorInfo& output,
261 const TensorInfo& weights,
262 const TensorInfo& biases,
263 const FullyConnectedDescriptor& descriptor,
264 Optional<std::string&> reasonIfUnsupported) const
265{
266 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
267 reasonIfUnsupported,
268 input,
269 output,
270 weights,
271 biases,
272 descriptor);
273}
274
275bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
276 Optional<std::string&> reasonIfUnsupported) const
277{
278 return IsSupportedForDataTypeCl(reasonIfUnsupported,
279 input.GetDataType(),
280 &TrueFunc<>,
281 &TrueFunc<>);
282}
283
284bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
285 const TensorInfo& output,
286 const L2NormalizationDescriptor& descriptor,
287 Optional<std::string&> reasonIfUnsupported) const
288{
289 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
290 reasonIfUnsupported,
291 input,
292 output,
293 descriptor);
294}
295
296bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
297 const TensorInfo& outputStateIn,
298 const TensorInfo& cellStateIn,
299 const TensorInfo& scratchBuffer,
300 const TensorInfo& outputStateOut,
301 const TensorInfo& cellStateOut,
302 const TensorInfo& output,
303 const LstmDescriptor& descriptor,
304 const TensorInfo& inputToForgetWeights,
305 const TensorInfo& inputToCellWeights,
306 const TensorInfo& inputToOutputWeights,
307 const TensorInfo& recurrentToForgetWeights,
308 const TensorInfo& recurrentToCellWeights,
309 const TensorInfo& recurrentToOutputWeights,
310 const TensorInfo& forgetGateBias,
311 const TensorInfo& cellBias,
312 const TensorInfo& outputGateBias,
313 const TensorInfo* inputToInputWeights,
314 const TensorInfo* recurrentToInputWeights,
315 const TensorInfo* cellToInputWeights,
316 const TensorInfo* inputGateBias,
317 const TensorInfo* projectionWeights,
318 const TensorInfo* projectionBias,
319 const TensorInfo* cellToForgetWeights,
320 const TensorInfo* cellToOutputWeights,
321 Optional<std::string&> reasonIfUnsupported) const
telsoa01c577f2c2018-08-31 09:22:23 +0100322{
arovir01085f0a42018-10-08 14:48:19 +0100323 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
324 reasonIfUnsupported,
325 input,
326 outputStateIn,
327 cellStateIn,
328 scratchBuffer,
329 outputStateOut,
330 cellStateOut,
331 output,
332 descriptor,
333 inputToForgetWeights,
334 inputToCellWeights,
335 inputToOutputWeights,
336 recurrentToForgetWeights,
337 recurrentToCellWeights,
338 recurrentToOutputWeights,
339 forgetGateBias,
340 cellBias,
341 outputGateBias,
342 inputToInputWeights,
343 recurrentToInputWeights,
344 cellToInputWeights,
345 inputGateBias,
346 projectionWeights,
347 projectionBias,
348 cellToForgetWeights,
349 cellToOutputWeights);
telsoa01c577f2c2018-08-31 09:22:23 +0100350}
351
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100352bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
353 const TensorInfo& output,
354 const MeanDescriptor& descriptor,
355 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100356{
arovir01085f0a42018-10-08 14:48:19 +0100357 ignore_unused(input);
358 ignore_unused(output);
359 ignore_unused(descriptor);
360 ignore_unused(reasonIfUnsupported);
narpra0132b90462018-09-13 11:07:48 +0100361 return false;
362}
363
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100364bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
365 const OriginsDescriptor& descriptor,
366 Optional<std::string&> reasonIfUnsupported) const
367{
368 ignore_unused(descriptor);
369 return IsSupportedForDataTypeCl(reasonIfUnsupported,
370 inputs[0]->GetDataType(),
371 &TrueFunc<>,
372 &FalseFuncU8<>);
373}
374
375bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
376 const TensorInfo& input1,
377 const TensorInfo& output,
378 Optional<std::string&> reasonIfUnsupported) const
379{
380 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
381 reasonIfUnsupported,
382 input0,
383 input1,
384 output);
385}
386
387bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
388 const TensorInfo& output,
389 const NormalizationDescriptor& descriptor,
390 Optional<std::string&> reasonIfUnsupported) const
391{
392 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
393}
394
395bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
396 Optional<std::string&> reasonIfUnsupported) const
397{
398 return IsSupportedForDataTypeCl(reasonIfUnsupported,
399 output.GetDataType(),
400 &TrueFunc<>,
401 &TrueFunc<>);
402}
403
404bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
405 const TensorInfo& output,
406 const PadDescriptor& descriptor,
407 Optional<std::string&> reasonIfUnsupported) const
arovir01085f0a42018-10-08 14:48:19 +0100408{
409 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
410 reasonIfUnsupported,
411 input,
412 output,
413 descriptor);
414}
415
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100416bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
417 const TensorInfo& output,
418 const PermuteDescriptor& descriptor,
419 Optional<std::string&> reasonIfUnsupported) const
420{
421 ignore_unused(input);
422 ignore_unused(output);
423 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000424}
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100425
426bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
427 const TensorInfo& output,
428 const Pooling2dDescriptor& descriptor,
429 Optional<std::string&> reasonIfUnsupported) const
430{
431 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
432}
433
434bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
435 Optional<std::string&> reasonIfUnsupported) const
436{
437 ignore_unused(input);
438 ignore_unused(reasonIfUnsupported);
439 return true;
440}
441
442bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
443 Optional<std::string&> reasonIfUnsupported) const
444{
445 return IsSupportedForDataTypeCl(reasonIfUnsupported,
446 input.GetDataType(),
447 &TrueFunc<>,
448 &FalseFuncU8<>);
449}
450
451bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
452 const TensorInfo& output,
453 const SoftmaxDescriptor& descriptor,
454 Optional<std::string&> reasonIfUnsupported) const
455{
456 ignore_unused(descriptor);
457 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
458}
459
460bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
461 const ViewsDescriptor& descriptor,
462 Optional<std::string&> reasonIfUnsupported) const
463{
464 ignore_unused(descriptor);
465 return IsSupportedForDataTypeCl(reasonIfUnsupported,
466 input.GetDataType(),
467 &TrueFunc<>,
468 &TrueFunc<>);
469}
470
471bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
472 const TensorInfo& input1,
473 const TensorInfo& output,
474 Optional<std::string&> reasonIfUnsupported) const
475{
476 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
477 reasonIfUnsupported,
478 input0,
479 input1,
480 output);
481}
482
483} // namespace armnn