blob: 039f1c24f0ebc96b00abb856bbf2de1c65eb6f9a [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "ClLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "ClBackendId.hpp"
arovir017c22c702018-10-09 11:16:46 +01008
David Beck3cc9a622018-10-12 10:38:31 +01009#include <armnn/Descriptors.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <InternalTypes.hpp>
11#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
David Beck111b5d92018-11-12 14:59:37 +000013#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010014
telsoa014fcda012018-03-09 14:13:49 +000015#include <boost/core/ignore_unused.hpp>
16
17#ifdef ARMCOMPUTECL_ENABLED
David Beckac42efd2018-09-26 17:41:13 +010018#include "workloads/ClAdditionWorkload.hpp"
Nattapat Chaimanowonge06757e2018-10-11 15:39:18 +010019#include "workloads/ClActivationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010020#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
21#include "workloads/ClConvertFp16ToFp32Workload.hpp"
22#include "workloads/ClConvertFp32ToFp16Workload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010023#include "workloads/ClConvolution2dWorkload.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +010024#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010025#include "workloads/ClDivisionFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010026#include "workloads/ClFullyConnectedWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010027#include "workloads/ClL2NormalizationFloatWorkload.hpp"
28#include "workloads/ClLstmFloatWorkload.hpp"
Matteo Martincigh28dcab62018-10-19 16:40:03 +010029#include "workloads/ClMeanWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010030#include "workloads/ClMultiplicationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010031#include "workloads/ClNormalizationFloatWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010032#include "workloads/ClPadWorkload.hpp"
33#include "workloads/ClPermuteWorkload.hpp"
Nattapat Chaimanowongac9e0962018-10-10 17:18:35 +010034#include "workloads/ClPooling2dWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010035#include "workloads/ClSoftmaxBaseWorkload.hpp"
36#include "workloads/ClSubtractionWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +000037#endif
38
39using namespace boost;
40
41namespace armnn
42{
arovir017c22c702018-10-09 11:16:46 +010043
telsoa014fcda012018-03-09 14:13:49 +000044namespace
45{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +010046
telsoa014fcda012018-03-09 14:13:49 +000047template<unsigned int FilterSize>
48bool IsMatchingSize2d(const TensorInfo& weightInfo)
49{
telsoa01c577f2c2018-08-31 09:22:23 +010050 // Width & Height must match.
telsoa014fcda012018-03-09 14:13:49 +000051 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
52}
53
// Single-candidate overload: true exactly when actualStride equals ValidStride.
// The variadic overload in this file bottoms out here.
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    const bool isSame = (actualStride == ValidStride);
    return isSame;
}
59
// Variadic overload: true when actualStride equals any of the listed strides.
// It tests FirstStride via the single-stride overload, then recurses on the
// remaining candidates until only one is left.
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    if (IsMatchingStride<FirstStride>(actualStride))
    {
        return true;
    }
    return IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
telsoa014fcda012018-03-09 14:13:49 +000065
arovir01085f0a42018-10-08 14:48:19 +010066bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000067{
68#if ARMCOMPUTECL_ENABLED
69 return true;
70#else
arovir01085f0a42018-10-08 14:48:19 +010071 if (reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000072 {
arovir01085f0a42018-10-08 14:48:19 +010073 reasonIfUnsupported.value() = "The armnn library has been built without CL support";
telsoa014fcda012018-03-09 14:13:49 +000074 }
75 return false;
76#endif
77}
78
#if !ARMCOMPUTECL_ENABLED
// No CL support: the wrapped expression is discarded, and the expansion instead
// reports that CL is unavailable. It relies on a `reasonIfUnsupported` variable
// in the caller's scope.
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#else
// CL support present: the wrapped expression is evaluated as-is.
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#endif
84
#if ARMCOMPUTECL_ENABLED
// Calls an ACL-style validate function and maps its arm_compute::Status to a bool.
// On any non-OK error code, the status' error description is copied into
// reasonIfUnsupported when the caller provided one.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    if (aclStatus.error_code() == arm_compute::ErrorCode::OK)
    {
        return true;
    }
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = aclStatus.error_description();
    }
    return false;
}

// Expands to a complete `return` statement that forwards the extra arguments to
// `func` through IsWorkloadSupported.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// Without CL support, `func` and its arguments are dropped and the expansion
// returns the (false) result of IsClBackendSupported.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
104
telsoa01c577f2c2018-08-31 09:22:23 +0100105template<typename FloatFunc, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100106bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000107 DataType dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100108 FloatFunc floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000109 Uint8Func uint8FuncPtr,
110 Params&&... params)
111{
112 return IsClBackendSupported(reasonIfUnsupported) &&
113 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
114 dataType,
115 floatFuncPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100116 floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000117 uint8FuncPtr,
118 std::forward<Params>(params)...);
119}
120
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100121} // anonymous namespace
122
123bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
124 const TensorInfo& output,
125 const ActivationDescriptor& descriptor,
126 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000127{
telsoa01c577f2c2018-08-31 09:22:23 +0100128 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
129 reasonIfUnsupported,
130 input,
131 output,
132 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000133}
134
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100135bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
136 const TensorInfo& input1,
137 const TensorInfo& output,
138 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000139{
arovir01085f0a42018-10-08 14:48:19 +0100140 FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
141 reasonIfUnsupported,
142 input0,
143 input1,
144 output);
telsoa014fcda012018-03-09 14:13:49 +0000145}
146
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100147bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
148 const TensorInfo& output,
149 const TensorInfo& mean,
150 const TensorInfo& var,
151 const TensorInfo& beta,
152 const TensorInfo& gamma,
153 const BatchNormalizationDescriptor& descriptor,
154 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000155{
telsoa01c577f2c2018-08-31 09:22:23 +0100156 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
157 reasonIfUnsupported,
158 input,
159 output,
160 mean,
161 var,
162 beta,
163 gamma,
164 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000165}
166
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100167bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
168 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000169{
170 return IsSupportedForDataTypeCl(reasonIfUnsupported,
171 output.GetDataType(),
172 &TrueFunc<>,
173 &FalseFuncU8<>);
174}
175
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100176bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
177 const TensorInfo& output,
178 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000179{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100180 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
181 reasonIfUnsupported,
182 input,
183 output);
telsoa014fcda012018-03-09 14:13:49 +0000184}
185
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100186bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
187 const TensorInfo& output,
188 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000189{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100190 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
191 reasonIfUnsupported,
192 input,
193 output);
telsoa014fcda012018-03-09 14:13:49 +0000194}
195
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100196bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
197 const TensorInfo& output,
198 const Convolution2dDescriptor& descriptor,
199 const TensorInfo& weights,
200 const Optional<TensorInfo>& biases,
201 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000202{
surmeh013537c2c2018-05-18 16:31:43 +0100203 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
204 reasonIfUnsupported,
205 input,
206 output,
207 descriptor,
208 weights,
209 biases);
telsoa014fcda012018-03-09 14:13:49 +0000210}
211
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100212bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
213 const TensorInfo& output,
214 const DepthwiseConvolution2dDescriptor& descriptor,
215 const TensorInfo& weights,
216 const Optional<TensorInfo>& biases,
217 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000218{
telsoa01c577f2c2018-08-31 09:22:23 +0100219 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
220 reasonIfUnsupported,
221 input,
222 output,
223 descriptor,
224 weights,
225 biases);
telsoa014fcda012018-03-09 14:13:49 +0000226}
227
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100228bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
229 const TensorInfo& input1,
230 const TensorInfo& output,
231 Optional<std::string&> reasonIfUnsupported) const
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100232{
233 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
234 reasonIfUnsupported,
235 input0,
236 input1,
237 output);
238}
239
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100240bool ClLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
241 const FakeQuantizationDescriptor& descriptor,
242 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000243{
244 ignore_unused(input);
245 ignore_unused(descriptor);
arovir01085f0a42018-10-08 14:48:19 +0100246 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000247 return false;
248}
249
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100250bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
251 const TensorInfo& output,
252 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000253{
254 ignore_unused(output);
telsoa01c577f2c2018-08-31 09:22:23 +0100255 return IsClBackendSupported(reasonIfUnsupported) &&
256 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
257 input.GetDataType(),
258 &FalseFuncF16<>,
259 &TrueFunc<>,
260 &FalseFuncU8<>);
261}
262
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100263bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
264 const TensorInfo& output,
265 const TensorInfo& weights,
266 const TensorInfo& biases,
267 const FullyConnectedDescriptor& descriptor,
268 Optional<std::string&> reasonIfUnsupported) const
269{
270 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
271 reasonIfUnsupported,
272 input,
273 output,
274 weights,
275 biases,
276 descriptor);
277}
278
279bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
280 Optional<std::string&> reasonIfUnsupported) const
281{
282 return IsSupportedForDataTypeCl(reasonIfUnsupported,
283 input.GetDataType(),
284 &TrueFunc<>,
285 &TrueFunc<>);
286}
287
288bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
289 const TensorInfo& output,
290 const L2NormalizationDescriptor& descriptor,
291 Optional<std::string&> reasonIfUnsupported) const
292{
293 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
294 reasonIfUnsupported,
295 input,
296 output,
297 descriptor);
298}
299
300bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
301 const TensorInfo& outputStateIn,
302 const TensorInfo& cellStateIn,
303 const TensorInfo& scratchBuffer,
304 const TensorInfo& outputStateOut,
305 const TensorInfo& cellStateOut,
306 const TensorInfo& output,
307 const LstmDescriptor& descriptor,
308 const TensorInfo& inputToForgetWeights,
309 const TensorInfo& inputToCellWeights,
310 const TensorInfo& inputToOutputWeights,
311 const TensorInfo& recurrentToForgetWeights,
312 const TensorInfo& recurrentToCellWeights,
313 const TensorInfo& recurrentToOutputWeights,
314 const TensorInfo& forgetGateBias,
315 const TensorInfo& cellBias,
316 const TensorInfo& outputGateBias,
317 const TensorInfo* inputToInputWeights,
318 const TensorInfo* recurrentToInputWeights,
319 const TensorInfo* cellToInputWeights,
320 const TensorInfo* inputGateBias,
321 const TensorInfo* projectionWeights,
322 const TensorInfo* projectionBias,
323 const TensorInfo* cellToForgetWeights,
324 const TensorInfo* cellToOutputWeights,
325 Optional<std::string&> reasonIfUnsupported) const
telsoa01c577f2c2018-08-31 09:22:23 +0100326{
arovir01085f0a42018-10-08 14:48:19 +0100327 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
328 reasonIfUnsupported,
329 input,
330 outputStateIn,
331 cellStateIn,
332 scratchBuffer,
333 outputStateOut,
334 cellStateOut,
335 output,
336 descriptor,
337 inputToForgetWeights,
338 inputToCellWeights,
339 inputToOutputWeights,
340 recurrentToForgetWeights,
341 recurrentToCellWeights,
342 recurrentToOutputWeights,
343 forgetGateBias,
344 cellBias,
345 outputGateBias,
346 inputToInputWeights,
347 recurrentToInputWeights,
348 cellToInputWeights,
349 inputGateBias,
350 projectionWeights,
351 projectionBias,
352 cellToForgetWeights,
353 cellToOutputWeights);
telsoa01c577f2c2018-08-31 09:22:23 +0100354}
355
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100356bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
357 const TensorInfo& output,
358 const MeanDescriptor& descriptor,
359 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100360{
Matteo Martincigh28dcab62018-10-19 16:40:03 +0100361 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
362 reasonIfUnsupported,
363 input,
364 output,
365 descriptor);
narpra0132b90462018-09-13 11:07:48 +0100366}
367
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100368bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
369 const OriginsDescriptor& descriptor,
370 Optional<std::string&> reasonIfUnsupported) const
371{
372 ignore_unused(descriptor);
373 return IsSupportedForDataTypeCl(reasonIfUnsupported,
374 inputs[0]->GetDataType(),
375 &TrueFunc<>,
376 &FalseFuncU8<>);
377}
378
379bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
380 const TensorInfo& input1,
381 const TensorInfo& output,
382 Optional<std::string&> reasonIfUnsupported) const
383{
384 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
385 reasonIfUnsupported,
386 input0,
387 input1,
388 output);
389}
390
391bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
392 const TensorInfo& output,
393 const NormalizationDescriptor& descriptor,
394 Optional<std::string&> reasonIfUnsupported) const
395{
396 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
397}
398
399bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
400 Optional<std::string&> reasonIfUnsupported) const
401{
402 return IsSupportedForDataTypeCl(reasonIfUnsupported,
403 output.GetDataType(),
404 &TrueFunc<>,
405 &TrueFunc<>);
406}
407
408bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
409 const TensorInfo& output,
410 const PadDescriptor& descriptor,
411 Optional<std::string&> reasonIfUnsupported) const
arovir01085f0a42018-10-08 14:48:19 +0100412{
413 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
414 reasonIfUnsupported,
415 input,
416 output,
417 descriptor);
418}
419
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100420bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
421 const TensorInfo& output,
422 const PermuteDescriptor& descriptor,
423 Optional<std::string&> reasonIfUnsupported) const
424{
425 ignore_unused(input);
426 ignore_unused(output);
427 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000428}
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100429
430bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
431 const TensorInfo& output,
432 const Pooling2dDescriptor& descriptor,
433 Optional<std::string&> reasonIfUnsupported) const
434{
435 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
436}
437
438bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
439 Optional<std::string&> reasonIfUnsupported) const
440{
441 ignore_unused(input);
442 ignore_unused(reasonIfUnsupported);
443 return true;
444}
445
446bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
447 Optional<std::string&> reasonIfUnsupported) const
448{
449 return IsSupportedForDataTypeCl(reasonIfUnsupported,
450 input.GetDataType(),
451 &TrueFunc<>,
452 &FalseFuncU8<>);
453}
454
455bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
456 const TensorInfo& output,
457 const SoftmaxDescriptor& descriptor,
458 Optional<std::string&> reasonIfUnsupported) const
459{
460 ignore_unused(descriptor);
461 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
462}
463
464bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
465 const ViewsDescriptor& descriptor,
466 Optional<std::string&> reasonIfUnsupported) const
467{
468 ignore_unused(descriptor);
469 return IsSupportedForDataTypeCl(reasonIfUnsupported,
470 input.GetDataType(),
471 &TrueFunc<>,
472 &TrueFunc<>);
473}
474
475bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
476 const TensorInfo& input1,
477 const TensorInfo& output,
478 Optional<std::string&> reasonIfUnsupported) const
479{
480 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
481 reasonIfUnsupported,
482 input0,
483 input1,
484 output);
485}
486
487} // namespace armnn