//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClLayerSupport.hpp"
#include "ClBackendId.hpp"

#include <armnn/Descriptors.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/BackendRegistry.hpp>

#include <boost/core/ignore_unused.hpp>

#if defined(ARMCOMPUTECL_ENABLED)
#include <aclCommon/ArmComputeUtils.hpp>
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClBatchToSpaceNdWorkload.hpp"
#include "workloads/ClConvertFp16ToFp32Workload.hpp"
#include "workloads/ClConvertFp32ToFp16Workload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDequantizeWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClGreaterWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#include "workloads/ClMaximumWorkload.hpp"
#include "workloads/ClMeanWorkload.hpp"
#include "workloads/ClConcatWorkload.hpp"
#include "workloads/ClMinimumWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClQuantizeWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSpaceToBatchNdWorkload.hpp"
#include "workloads/ClSplitterWorkload.hpp"
#include "workloads/ClStridedSliceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#include "workloads/ClTransposeConvolution2dWorkload.hpp"
#endif

using namespace boost;

namespace armnn
{

namespace
{

template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

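// Returns true if actualStride matches any of the ValidStrides; the variadic overload
// checks each candidate stride in turn.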
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) ||
           IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}

bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
{
#if defined(ARMCOMPUTECL_ENABLED)
    return true;
#else
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

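// Forwards the expression to the Arm Compute CL check when the backend is compiled in;
// otherwise reports that the library was built without CL support.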
#if defined(ARMCOMPUTECL_ENABLED)
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif

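// Converts the arm_compute::Status returned by a workload's validate function into a bool,
// copying the error description into reasonIfUnsupported on failure.
// FORWARD_WORKLOAD_VALIDATE_FUNC expands to a return statement, e.g.
//     FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);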
#if defined(ARMCOMPUTECL_ENABLED)
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif

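// Forwards to IsSupportedForDataTypeGeneric: floatFuncPtr handles both float data types,
// uint8FuncPtr handles the quantised 8-bit type, and the trailing FalseFunc entries
// reject the remaining data types.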
template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      &FalseFunc<>,
                                      &FalseFunc<>,
                                      std::forward<Params>(params)...);
}

} // anonymous namespace

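// Most of the checks below forward to the corresponding workload's ACL validate function via
// FORWARD_WORKLOAD_VALIDATE_FUNC; the rest fall back to simple data type checks.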
bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const ActivationDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   const TensorInfo& mean,
                                                   const TensorInfo& var,
                                                   const TensorInfo& beta,
                                                   const TensorInfo& gamma,
                                                   const BatchNormalizationDescriptor& descriptor,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const BatchToSpaceNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchToSpaceNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const ConcatDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
    {
        SetValueChecked(reasonIfUnsupported, "Cl Concat: Concat axis > Number of dimensions.");
        return false;
    }

    unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
    if (concatInnerAxis < 3) // Width, height, or channels
    {
        FORWARD_WORKLOAD_VALIDATE_FUNC(ClConcatWorkloadValidate,
                                       reasonIfUnsupported,
                                       inputs,
                                       output,
                                       descriptor);
    }
    else if (concatInnerAxis == 3)
    {
        // We rely on the sub-tensor optimization to handle the batch dimension for 4D tensors. If we can't use
        // sub-tensors for this then we can't support it. Here is where we check that the sub-tensors will work.
        for (auto& input : inputs)
        {
            if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
            {
                SetValueChecked(reasonIfUnsupported, "Cl Concat: Types and quantization parameters must match.");
                return false;
            }
        }
        return true; // Sub-tensors support concat along batch
    }
    else // > 4 dimensions not supported.
    {
        SetValueChecked(reasonIfUnsupported, "Cl Concat: Maximum of 4 dimensions supported.");
        return false;
    }
}

bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const Convolution2dDescriptor& descriptor,
                                              const TensorInfo& weights,
                                              const Optional<TensorInfo>& biases,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDequantizeWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                     const TensorInfo& output,
                                                     const DepthwiseConvolution2dDescriptor& descriptor,
                                                     const TensorInfo& weights,
                                                     const Optional<TensorInfo>& biases,
                                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                            const TensorInfo& output,
                                                            const DepthwiseConvolution2dDescriptor& descriptor,
                                                            const TensorInfo& weights,
                                                            const Optional<TensorInfo>& biases,
                                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &FalseFuncF16<>,
                                         &TrueFunc<>,
                                         &FalseFuncU8<>,
                                         &FalseFuncI32<>,
                                         &FalseFuncU8<>);
}

bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const TensorInfo& weights,
                                               const TensorInfo& biases,
                                               const FullyConnectedDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const L2NormalizationDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
                                     const TensorInfo& outputStateIn,
                                     const TensorInfo& cellStateIn,
                                     const TensorInfo& scratchBuffer,
                                     const TensorInfo& outputStateOut,
                                     const TensorInfo& cellStateOut,
                                     const TensorInfo& output,
                                     const LstmDescriptor& descriptor,
                                     const TensorInfo& inputToForgetWeights,
                                     const TensorInfo& inputToCellWeights,
                                     const TensorInfo& inputToOutputWeights,
                                     const TensorInfo& recurrentToForgetWeights,
                                     const TensorInfo& recurrentToCellWeights,
                                     const TensorInfo& recurrentToOutputWeights,
                                     const TensorInfo& forgetGateBias,
                                     const TensorInfo& cellBias,
                                     const TensorInfo& outputGateBias,
                                     const TensorInfo* inputToInputWeights,
                                     const TensorInfo* recurrentToInputWeights,
                                     const TensorInfo* cellToInputWeights,
                                     const TensorInfo* inputGateBias,
                                     const TensorInfo* projectionWeights,
                                     const TensorInfo* projectionBias,
                                     const TensorInfo* cellToForgetWeights,
                                     const TensorInfo* cellToOutputWeights,
                                     Optional<std::string&> reasonIfUnsupported,
                                     const TensorInfo* inputLayerNormWeights,
                                     const TensorInfo* forgetLayerNormWeights,
                                     const TensorInfo* cellLayerNormWeights,
                                     const TensorInfo* outputLayerNormWeights) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   outputStateIn,
                                   cellStateIn,
                                   scratchBuffer,
                                   outputStateOut,
                                   cellStateOut,
                                   output,
                                   descriptor,
                                   inputToForgetWeights,
                                   inputToCellWeights,
                                   inputToOutputWeights,
                                   recurrentToForgetWeights,
                                   recurrentToCellWeights,
                                   recurrentToOutputWeights,
                                   forgetGateBias,
                                   cellBias,
                                   outputGateBias,
                                   inputToInputWeights,
                                   recurrentToInputWeights,
                                   cellToInputWeights,
                                   inputGateBias,
                                   projectionWeights,
                                   projectionBias,
                                   cellToForgetWeights,
                                   cellToOutputWeights);
}

bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const MeanDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsMemCopySupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    return true;
}

bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                       const TensorInfo& output,
                                       const MergerDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}

bool ClLayerSupport::IsMinimumSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                               const TensorInfo& input1,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const NormalizationDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         output.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}

bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
                                    const TensorInfo& output,
                                    const PadDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const PermuteDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          const Pooling2dDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool ClLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClQuantizeWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output);
}

bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                        const ReshapeDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input);
    ignore_unused(descriptor);
    ignore_unused(reasonIfUnsupported);
    return true;
}

bool ClLayerSupport::IsResizeSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const ResizeDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);

    if (descriptor.m_Method == ResizeMethod::Bilinear)
    {
        return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                        input.GetDataType(),
                                        &TrueFunc<>,
                                        &FalseFuncU8<>);
    }

    return false;
}

bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const SoftmaxDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const SpaceToBatchNdDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToBatchNdWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                         const ViewsDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                         const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
                                         const ViewsDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
#if defined(ARMCOMPUTECL_ENABLED)
    // When splitting a tensor with more than 2 dimensions along its last dimension, sub-tensors
    // cannot be used because their width and height would not match those of the parent tensor,
    // so validate the dedicated splitter workload instead.
    std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor, input.GetShape());
    if (descriptor.GetNumDimensions() > 2 && splitAxis.size() == 1 &&
        *splitAxis.begin() == descriptor.GetNumDimensions() - 1)
    {
        FORWARD_WORKLOAD_VALIDATE_FUNC(ClSplitterWorkloadValidate,
                                       reasonIfUnsupported,
                                       input,
                                       outputs,
                                       *splitAxis.begin());
    }
#endif
    for (auto output : outputs)
    {
        if (!input.IsTypeSpaceMatch(output)) // Cannot use sub-tensors if the types are not same space
        {
            SetValueChecked(reasonIfUnsupported, "Cl Splitter: Types and quantization parameters must match.");
            return false;
        }
    }
    return true;
}

bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const StridedSliceDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClStridedSliceWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool ClLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const TransposeConvolution2dDescriptor& descriptor,
                                                       const TensorInfo& weights,
                                                       const Optional<TensorInfo>& biases,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClTransposeConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

} // namespace armnn