//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClLayerSupport.hpp"
#include "ClBackendId.hpp"

#include <armnn/Descriptors.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/BackendRegistry.hpp>

#include <boost/core/ignore_unused.hpp>

#if defined(ARMCOMPUTECL_ENABLED)
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClBatchToSpaceNdWorkload.hpp"
#include "workloads/ClConvertFp16ToFp32Workload.hpp"
#include "workloads/ClConvertFp32ToFp16Workload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClGreaterWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#include "workloads/ClMaximumWorkload.hpp"
#include "workloads/ClMeanWorkload.hpp"
#include "workloads/ClMergerWorkload.hpp"
#include "workloads/ClMinimumWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSpaceToBatchNdWorkload.hpp"
#include "workloads/ClStridedSliceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#endif

using namespace boost;

namespace armnn
{

namespace
{

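// Returns true when the weight tensor's width and height (shape dimensions 3 and 2) both equal FilterSize.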
template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

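// Returns true when actualStride matches any of the strides given as template arguments;
// the single-parameter overload is the base case, the variadic overload checks each candidate in turn.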
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}

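// Reports whether armnn was built with CL support (ARMCOMPUTECL_ENABLED). When it was not,
// the reason is written to reasonIfUnsupported and false is returned.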
bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
{
#if defined(ARMCOMPUTECL_ENABLED)
    return true;
#else
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

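// With CL support compiled in, the macro evaluates the supplied expression; otherwise it
// falls back to the generic "built without CL support" check above.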
#if defined(ARMCOMPUTECL_ENABLED)
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

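// Converts the arm_compute::Status returned by a workload's Validate() call into a bool and,
// on failure, copies the error description into reasonIfUnsupported. The accompanying
// FORWARD_WORKLOAD_VALIDATE_FUNC macro expands to a return statement, so each per-layer
// check below reduces to a single forwarding call.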
#if defined(ARMCOMPUTECL_ENABLED)
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

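// Data-type dispatch for layers that have no ACL validate function: Float16 and Float32 share
// floatFuncPtr, uint8FuncPtr handles the quantised 8-bit type, and the two trailing FalseFunc
// entries reject the remaining data types.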
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      &FalseFunc<>,
                                      &FalseFunc<>,
                                      std::forward<Params>(params)...);
}

} // anonymous namespace

bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const ActivationDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   const TensorInfo& mean,
                                                   const TensorInfo& var,
                                                   const TensorInfo& beta,
                                                   const TensorInfo& gamma,
                                                   const BatchNormalizationDescriptor& descriptor,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const BatchToSpaceNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchToSpaceNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

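// Concat support defers to the (deprecated) merger check; both entry points share one implementation.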
bool ClLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const OriginsDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ARMNN_NO_DEPRECATE_WARN_BEGIN
    return IsMergerSupported(inputs, output, descriptor, reasonIfUnsupported);
    ARMNN_NO_DEPRECATE_WARN_END
}

bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const Convolution2dDescriptor& descriptor,
                                              const TensorInfo& weights,
                                              const Optional<TensorInfo>& biases,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                     const TensorInfo& output,
                                                     const DepthwiseConvolution2dDescriptor& descriptor,
                                                     const TensorInfo& weights,
                                                     const Optional<TensorInfo>& biases,
                                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

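// Dilated depthwise convolution is validated by the same ACL workload as the regular depthwise path.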
bool ClLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                            const TensorInfo& output,
                                                            const DepthwiseConvolution2dDescriptor& descriptor,
                                                            const TensorInfo& weights,
                                                            const Optional<TensorInfo>& biases,
                                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

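// Floor has no ACL validate function; the data-type dispatch below accepts Float32 only.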
bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &FalseFuncF16<>,
                                         &TrueFunc<>,
                                         &FalseFuncU8<>,
                                         &FalseFuncI32<>,
                                         &FalseFuncU8<>);
}

bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const TensorInfo& weights,
                                               const TensorInfo& biases,
                                               const FullyConnectedDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const L2NormalizationDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
                                     const TensorInfo& outputStateIn,
                                     const TensorInfo& cellStateIn,
                                     const TensorInfo& scratchBuffer,
                                     const TensorInfo& outputStateOut,
                                     const TensorInfo& cellStateOut,
                                     const TensorInfo& output,
                                     const LstmDescriptor& descriptor,
                                     const TensorInfo& inputToForgetWeights,
                                     const TensorInfo& inputToCellWeights,
                                     const TensorInfo& inputToOutputWeights,
                                     const TensorInfo& recurrentToForgetWeights,
                                     const TensorInfo& recurrentToCellWeights,
                                     const TensorInfo& recurrentToOutputWeights,
                                     const TensorInfo& forgetGateBias,
                                     const TensorInfo& cellBias,
                                     const TensorInfo& outputGateBias,
                                     const TensorInfo* inputToInputWeights,
                                     const TensorInfo* recurrentToInputWeights,
                                     const TensorInfo* cellToInputWeights,
                                     const TensorInfo* inputGateBias,
                                     const TensorInfo* projectionWeights,
                                     const TensorInfo* projectionBias,
                                     const TensorInfo* cellToForgetWeights,
                                     const TensorInfo* cellToOutputWeights,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   outputStateIn,
                                   cellStateIn,
                                   scratchBuffer,
                                   outputStateOut,
                                   cellStateOut,
                                   output,
                                   descriptor,
                                   inputToForgetWeights,
                                   inputToCellWeights,
                                   inputToOutputWeights,
                                   recurrentToForgetWeights,
                                   recurrentToCellWeights,
                                   recurrentToOutputWeights,
                                   forgetGateBias,
                                   cellBias,
                                   outputGateBias,
                                   inputToInputWeights,
                                   recurrentToInputWeights,
                                   cellToInputWeights,
                                   inputGateBias,
                                   projectionWeights,
                                   projectionBias,
                                   cellToForgetWeights,
                                   cellToOutputWeights);
}

bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const MeanDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsMemCopySupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    return true;
}

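// Concat along width, height or channels is validated against the ACL merger workload.
// Concat along the batch axis of 4D tensors relies on the sub-tensor optimization, so every
// input must match the output's type and quantization space.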
bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const OriginsDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
    {
        SetValueChecked(reasonIfUnsupported, "Cl Merger: Concat axis > Number of dimensions.");
        return false;
    }

    unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
    if (concatInnerAxis < 3) // Width, height, or channels
    {
        FORWARD_WORKLOAD_VALIDATE_FUNC(ClMergerWorkloadValidate,
                                       reasonIfUnsupported,
                                       inputs,
                                       output,
                                       descriptor);
    }
    else if (concatInnerAxis == 3)
    {
        // We rely on the sub-tensor optimization to handle the batch dimension for 4D tensors. If we can't use
        // sub-tensors for this then we can't support it. Here is where we check that the sub-tensors will work.
        for (auto& input : inputs)
        {
            if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
            {
                SetValueChecked(reasonIfUnsupported, "Cl Merger: Types and quantization parameters must match.");
                return false;
            }
        }
        return true; // Sub-tensors support concat along batch
    }
    else // > 4 dimensions not supported.
    {
        SetValueChecked(reasonIfUnsupported, "Cl Merger: Maximum of 4 dimensions supported.");
        return false;
    }
}

bool ClLayerSupport::IsMinimumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                               const TensorInfo& input1,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const NormalizationDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

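// Output accepts every data type handled by the generic dispatch except Signed32.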
bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}

bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
                                    const TensorInfo& output,
                                    const PadDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const PermuteDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          const Pooling2dDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

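// Reshape requires no ACL validation and is supported unconditionally.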
bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                        const ReshapeDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(descriptor);
    ignore_unused(reasonIfUnsupported);
    return true;
}

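// ResizeBilinear is supported for the float types only; the quantised 8-bit path is rejected.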
bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const SoftmaxDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const SpaceToBatchNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToBatchNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                         const ViewsDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const StridedSliceDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClStridedSliceWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

} // namespace armnn