//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClLayerSupport.hpp"
#include "ClBackendId.hpp"

#include <armnn/Descriptors.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/BackendRegistry.hpp>

#include <boost/core/ignore_unused.hpp>

#if defined(ARMCOMPUTECL_ENABLED)
#include <aclCommon/ArmComputeUtils.hpp>
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClBatchToSpaceNdWorkload.hpp"
#include "workloads/ClConvertFp16ToFp32Workload.hpp"
#include "workloads/ClConvertFp32ToFp16Workload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClGreaterWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#include "workloads/ClMaximumWorkload.hpp"
#include "workloads/ClMeanWorkload.hpp"
#include "workloads/ClConcatWorkload.hpp"
#include "workloads/ClMinimumWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSpaceToBatchNdWorkload.hpp"
#include "workloads/ClSplitterWorkload.hpp"
#include "workloads/ClStridedSliceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#endif

using namespace boost;

namespace armnn
{

namespace
{

template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
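// Example use of IsMatchingStride (a hypothetical call, shown only for illustration):
// IsMatchingStride<1, 2>(stride) is true when stride is 1 or 2; the variadic overload unrolls the
// parameter pack into a chain of single-stride checks against the base case above.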

bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
{
#if defined(ARMCOMPUTECL_ENABLED)
    return true;
#else
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

#if defined(ARMCOMPUTECL_ENABLED)
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

#if defined(ARMCOMPUTECL_ENABLED)
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

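// Most of the IsXxxSupported queries below are implemented with FORWARD_WORKLOAD_VALIDATE_FUNC:
// with CL enabled it forwards the arguments to the matching ClXxxWorkloadValidate function and
// converts the returned arm_compute::Status into a bool (copying the error description into
// reasonIfUnsupported on failure); with CL disabled it falls back to IsClBackendSupported.
// For example, IsActivationSupported expands to roughly:
//     return IsWorkloadSupported(ClActivationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
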
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      &FalseFunc<>,
                                      &FalseFunc<>,
                                      std::forward<Params>(params)...);
}
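// The two trailing FalseFunc<> arguments above reject the remaining data types dispatched by
// IsSupportedForDataTypeGeneric (assumed here to be Signed32 and Boolean, matching the argument
// order used in IsOutputSupported further down).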

} // anonymous namespace

bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const ActivationDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   const TensorInfo& mean,
                                                   const TensorInfo& var,
                                                   const TensorInfo& beta,
                                                   const TensorInfo& gamma,
                                                   const BatchNormalizationDescriptor& descriptor,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const BatchToSpaceNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchToSpaceNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const OriginsDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ARMNN_NO_DEPRECATE_WARN_BEGIN
    return IsMergerSupported(inputs, output, descriptor, reasonIfUnsupported);
    ARMNN_NO_DEPRECATE_WARN_END
}

bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const Convolution2dDescriptor& descriptor,
                                              const TensorInfo& weights,
                                              const Optional<TensorInfo>& biases,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                     const TensorInfo& output,
                                                     const DepthwiseConvolution2dDescriptor& descriptor,
                                                     const TensorInfo& weights,
                                                     const Optional<TensorInfo>& biases,
                                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                            const TensorInfo& output,
                                                            const DepthwiseConvolution2dDescriptor& descriptor,
                                                            const TensorInfo& weights,
                                                            const Optional<TensorInfo>& biases,
                                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
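    // Floor is reported as supported only for Float32 here; every other data type maps to a
    // FalseFunc variant below.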
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      input.GetDataType(),
                                      &FalseFuncF16<>,
                                      &TrueFunc<>,
                                      &FalseFuncU8<>,
                                      &FalseFuncI32<>,
                                      &FalseFuncU8<>);
}

bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const TensorInfo& weights,
                                               const TensorInfo& biases,
                                               const FullyConnectedDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const L2NormalizationDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
                                     const TensorInfo& outputStateIn,
                                     const TensorInfo& cellStateIn,
                                     const TensorInfo& scratchBuffer,
                                     const TensorInfo& outputStateOut,
                                     const TensorInfo& cellStateOut,
                                     const TensorInfo& output,
                                     const LstmDescriptor& descriptor,
                                     const TensorInfo& inputToForgetWeights,
                                     const TensorInfo& inputToCellWeights,
                                     const TensorInfo& inputToOutputWeights,
                                     const TensorInfo& recurrentToForgetWeights,
                                     const TensorInfo& recurrentToCellWeights,
                                     const TensorInfo& recurrentToOutputWeights,
                                     const TensorInfo& forgetGateBias,
                                     const TensorInfo& cellBias,
                                     const TensorInfo& outputGateBias,
                                     const TensorInfo* inputToInputWeights,
                                     const TensorInfo* recurrentToInputWeights,
                                     const TensorInfo* cellToInputWeights,
                                     const TensorInfo* inputGateBias,
                                     const TensorInfo* projectionWeights,
                                     const TensorInfo* projectionBias,
                                     const TensorInfo* cellToForgetWeights,
                                     const TensorInfo* cellToOutputWeights,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   outputStateIn,
                                   cellStateIn,
                                   scratchBuffer,
                                   outputStateOut,
                                   cellStateOut,
                                   output,
                                   descriptor,
                                   inputToForgetWeights,
                                   inputToCellWeights,
                                   inputToOutputWeights,
                                   recurrentToForgetWeights,
                                   recurrentToCellWeights,
                                   recurrentToOutputWeights,
                                   forgetGateBias,
                                   cellBias,
                                   outputGateBias,
                                   inputToInputWeights,
                                   recurrentToInputWeights,
                                   cellToInputWeights,
                                   inputGateBias,
                                   projectionWeights,
                                   projectionBias,
                                   cellToForgetWeights,
                                   cellToOutputWeights);
}

bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const MeanDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsMemCopySupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    return true;
}

bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const OriginsDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
    {
        SetValueChecked(reasonIfUnsupported, "Cl Merger: Concat axis >= Number of dimensions.");
        return false;
    }

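    // Re-index the concat axis so that 0 refers to the innermost dimension (e.g. W for an NCHW tensor).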
    unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
    if (concatInnerAxis < 3) // Width, height, or channels
    {
        FORWARD_WORKLOAD_VALIDATE_FUNC(ClConcatWorkloadValidate,
                                       reasonIfUnsupported,
                                       inputs,
                                       output,
                                       descriptor);
    }
    else if (concatInnerAxis == 3)
    {
        // We rely on the sub-tensor optimization to handle the batch dimension for 4D tensors. If we can't use
        // sub-tensors for this then we can't support it. Here is where we check that the sub-tensors will work.
        for (auto& input : inputs)
        {
            if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
            {
                SetValueChecked(reasonIfUnsupported, "Cl Merger: Types and quantization parameters must match.");
                return false;
            }
        }
        return true; // Sub-tensors support concat along batch
    }
    else // > 4 dimensions not supported.
    {
        SetValueChecked(reasonIfUnsupported, "Cl Merger: Maximum of 4 dimensions supported.");
        return false;
    }
}

bool ClLayerSupport::IsMinimumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                               const TensorInfo& input1,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const NormalizationDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      output.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>,
                                      &TrueFunc<>,
                                      &FalseFuncI32<>,
                                      &TrueFunc<>);
}

bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
                                    const TensorInfo& output,
                                    const PadDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const PermuteDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          const Pooling2dDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                        const ReshapeDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(descriptor);
    ignore_unused(reasonIfUnsupported);
    return true;
}

bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const SoftmaxDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const SpaceToBatchNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToBatchNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                         const ViewsDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                         const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
                                         const ViewsDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
#if defined(ARMCOMPUTECL_ENABLED)
    // Split along the last dimension: sub-tensors cannot be used here because, for inputs with
    // more than two dimensions, the width and height of the sub-tensors would not match those of
    // the parent tensor, so validate against the ACL workload instead.
    std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor, input.GetShape());
    if (descriptor.GetNumDimensions() > 2 && splitAxis.size() == 1 &&
        *splitAxis.begin() == descriptor.GetNumDimensions() - 1)
    {
        FORWARD_WORKLOAD_VALIDATE_FUNC(ClSplitterWorkloadValidate,
                                       reasonIfUnsupported,
                                       input,
                                       outputs,
                                       *splitAxis.begin());
    }
#endif
    for (auto output : outputs)
    {
        if (!input.IsTypeSpaceMatch(output)) // Cannot use sub-tensors if the types are not same space
        {
            SetValueChecked(reasonIfUnsupported, "Cl Splitter: Types and quantization parameters must match.");
            return false;
        }
    }
    return true;
}

bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const StridedSliceDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClStridedSliceWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

} // namespace armnn