//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClLayerSupport.hpp"
#include "ClBackendId.hpp"

#include <armnn/Descriptors.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/BackendRegistry.hpp>

#include <boost/core/ignore_unused.hpp>

#if defined(ARMCOMPUTECL_ENABLED)
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClBatchToSpaceNdWorkload.hpp"
#include "workloads/ClConvertFp16ToFp32Workload.hpp"
#include "workloads/ClConvertFp32ToFp16Workload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClGreaterWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#include "workloads/ClMaximumWorkload.hpp"
#include "workloads/ClMeanWorkload.hpp"
#include "workloads/ClMergerWorkload.hpp"
#include "workloads/ClMinimumWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSpaceToBatchNdWorkload.hpp"
#include "workloads/ClStridedSliceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#endif

using namespace boost;

namespace armnn
{

namespace
{

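// Returns true if the weight tensor describes a square FilterSize x FilterSize kernel,
// i.e. its width and height dimensions both equal FilterSize.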
template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

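// Variadic overload: expands the ValidStrides pack recursively, so for example
// IsMatchingStride<1, 2, 3>(actualStride) is true if actualStride is any of 1, 2 or 3.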
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}

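// Reports whether this build of Arm NN includes the Arm Compute OpenCL backend.
// When CL support has been compiled out, an explanatory reason is written back to the caller.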
bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
{
#if defined(ARMCOMPUTECL_ENABLED)
    return true;
#else
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

#if defined(ARMCOMPUTECL_ENABLED)
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

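// With CL enabled, FORWARD_WORKLOAD_VALIDATE_FUNC forwards a support query to the static
// Validate function of the corresponding CL workload and converts the resulting
// arm_compute::Status into a bool, copying the error description into reasonIfUnsupported
// on failure. Without CL it falls back to IsClBackendSupported, so every query reports
// that the library was built without CL support.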
#if defined(ARMCOMPUTECL_ENABLED)
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

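// Dispatches a support query on the tensor data type: Float16 and Float32 both use
// floatFuncPtr, the 8-bit quantised type uses uint8FuncPtr, and the remaining data
// types are rejected by this helper.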
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         floatFuncPtr,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<>,
                                         &FalseFunc<>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

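// The ClLayerSupport queries below mostly forward straight to the Validate function of
// the matching CL workload via FORWARD_WORKLOAD_VALIDATE_FUNC; the rest are answered
// from the tensor data types alone.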
bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const ActivationDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   const TensorInfo& mean,
                                                   const TensorInfo& var,
                                                   const TensorInfo& beta,
                                                   const TensorInfo& gamma,
                                                   const BatchNormalizationDescriptor& descriptor,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const BatchToSpaceNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchToSpaceNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const Convolution2dDescriptor& descriptor,
                                              const TensorInfo& weights,
                                              const Optional<TensorInfo>& biases,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                     const TensorInfo& output,
                                                     const DepthwiseConvolution2dDescriptor& descriptor,
                                                     const TensorInfo& weights,
                                                     const Optional<TensorInfo>& biases,
                                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                            const TensorInfo& output,
                                                            const DepthwiseConvolution2dDescriptor& descriptor,
                                                            const TensorInfo& weights,
                                                            const Optional<TensorInfo>& biases,
                                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

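// Floor is restricted purely by data type: Float32 is accepted, every other type is rejected.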
bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &FalseFuncF16<>,
                                         &TrueFunc<>,
                                         &FalseFuncU8<>,
                                         &FalseFuncI32<>,
                                         &FalseFuncU8<>);
}

bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const TensorInfo& weights,
                                               const TensorInfo& biases,
                                               const FullyConnectedDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const L2NormalizationDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
                                     const TensorInfo& outputStateIn,
                                     const TensorInfo& cellStateIn,
                                     const TensorInfo& scratchBuffer,
                                     const TensorInfo& outputStateOut,
                                     const TensorInfo& cellStateOut,
                                     const TensorInfo& output,
                                     const LstmDescriptor& descriptor,
                                     const TensorInfo& inputToForgetWeights,
                                     const TensorInfo& inputToCellWeights,
                                     const TensorInfo& inputToOutputWeights,
                                     const TensorInfo& recurrentToForgetWeights,
                                     const TensorInfo& recurrentToCellWeights,
                                     const TensorInfo& recurrentToOutputWeights,
                                     const TensorInfo& forgetGateBias,
                                     const TensorInfo& cellBias,
                                     const TensorInfo& outputGateBias,
                                     const TensorInfo* inputToInputWeights,
                                     const TensorInfo* recurrentToInputWeights,
                                     const TensorInfo* cellToInputWeights,
                                     const TensorInfo* inputGateBias,
                                     const TensorInfo* projectionWeights,
                                     const TensorInfo* projectionBias,
                                     const TensorInfo* cellToForgetWeights,
                                     const TensorInfo* cellToOutputWeights,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   outputStateIn,
                                   cellStateIn,
                                   scratchBuffer,
                                   outputStateOut,
                                   cellStateOut,
                                   output,
                                   descriptor,
                                   inputToForgetWeights,
                                   inputToCellWeights,
                                   inputToOutputWeights,
                                   recurrentToForgetWeights,
                                   recurrentToCellWeights,
                                   recurrentToOutputWeights,
                                   forgetGateBias,
                                   cellBias,
                                   outputGateBias,
                                   inputToInputWeights,
                                   recurrentToInputWeights,
                                   cellToInputWeights,
                                   inputGateBias,
                                   projectionWeights,
                                   projectionBias,
                                   cellToForgetWeights,
                                   cellToOutputWeights);
}

bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const MeanDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsMemCopySupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    return true;
}

bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const OriginsDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
    {
        SetValueChecked(reasonIfUnsupported, "Cl Merger: Concat axis > Number of dimensions.");
        return false;
    }

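    // Distance of the concat axis from the innermost dimension: 0, 1 and 2 (width, height,
    // channels) are passed to ClMergerWorkloadValidate; concatenation along the batch
    // dimension is handled by the sub-tensor check below.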
    unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
    if (concatInnerAxis < 3) // Width, height, or channels
    {
        FORWARD_WORKLOAD_VALIDATE_FUNC(ClMergerWorkloadValidate,
                                       reasonIfUnsupported,
                                       inputs,
                                       output,
                                       descriptor);
    }
    else if (concatInnerAxis == 3)
    {
        // We rely on the sub-tensor optimization to handle the batch dimension for 4D tensors. If we can't use
        // sub-tensors for this then we can't support it. Here is where we check that the sub-tensors will work.
        for (auto& input : inputs)
        {
            if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
            {
                SetValueChecked(reasonIfUnsupported, "Cl Merger: Types and quantization parameters must match.");
                return false;
            }
        }
        return true; // Sub-tensors support concat along batch
    }
    else // > 4 dimensions not supported.
    {
        SetValueChecked(reasonIfUnsupported, "Cl Merger: Maximum of 4 dimensions supported.");
        return false;
    }
}

bool ClLayerSupport::IsMinimumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                               const TensorInfo& input1,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const NormalizationDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}

bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
                                    const TensorInfo& output,
                                    const PadDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const PermuteDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          const Pooling2dDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                        const ReshapeDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(descriptor);
    ignore_unused(reasonIfUnsupported);
    return true;
}

bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const SoftmaxDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const SpaceToBatchNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToBatchNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                         const ViewsDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const StridedSliceDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClStridedSliceWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

} // namespace armnn