blob: 68563944b4173ec39e8ad537a40eed09bc1ac352 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "ClLayerSupport.hpp"
arovir017c22c702018-10-09 11:16:46 +01007
telsoa014fcda012018-03-09 14:13:49 +00008#include "InternalTypes.hpp"
arovir017c22c702018-10-09 11:16:46 +01009#include "LayerSupportCommon.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010
11#include <boost/core/ignore_unused.hpp>
12
13#ifdef ARMCOMPUTECL_ENABLED
David Beckac42efd2018-09-26 17:41:13 +010014#include "workloads/ClAdditionWorkload.hpp"
15#include "workloads/ClActivationFloatWorkload.hpp"
16#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
17#include "workloads/ClConvertFp16ToFp32Workload.hpp"
18#include "workloads/ClConvertFp32ToFp16Workload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010019#include "workloads/ClConvolution2dWorkload.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +010020#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010021#include "workloads/ClDivisionFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010022#include "workloads/ClFullyConnectedWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010023#include "workloads/ClL2NormalizationFloatWorkload.hpp"
24#include "workloads/ClLstmFloatWorkload.hpp"
25#include "workloads/ClMultiplicationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010026#include "workloads/ClNormalizationFloatWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010027#include "workloads/ClPadWorkload.hpp"
28#include "workloads/ClPermuteWorkload.hpp"
29#include "workloads/ClPooling2dBaseWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010030#include "workloads/ClSoftmaxBaseWorkload.hpp"
31#include "workloads/ClSubtractionWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +000032#endif
33
34using namespace boost;
35
36namespace armnn
37{
arovir017c22c702018-10-09 11:16:46 +010038
39bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
40 const TensorInfo& output,
41 const ActivationDescriptor& descriptor,
42 Optional<std::string&> reasonIfUnsupported) const
43{
44 return armnn::IsActivationSupportedCl(input, output, descriptor, reasonIfUnsupported);
45}
46
47bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
48 const TensorInfo& input1,
49 const TensorInfo& output,
50 Optional<std::string&> reasonIfUnsupported) const
51{
52 return armnn::IsAdditionSupportedCl(input0, input1, output, reasonIfUnsupported);
53}
54
55bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
56 const TensorInfo& output,
57 const TensorInfo& mean,
58 const TensorInfo& var,
59 const TensorInfo& beta,
60 const TensorInfo& gamma,
61 const BatchNormalizationDescriptor& descriptor,
62 Optional<std::string&> reasonIfUnsupported) const
63{
64 return armnn::IsBatchNormalizationSupportedCl(input,
65 output,
66 mean,
67 var,
68 beta,
69 gamma,
70 descriptor,
71 reasonIfUnsupported);
72}
73
74bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
75 Optional<std::string&> reasonIfUnsupported) const
76{
77 return armnn::IsConstantSupportedCl(output, reasonIfUnsupported);
78}
79
80bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
81 const TensorInfo& output,
82 Optional<std::string&> reasonIfUnsupported) const
83{
84 return armnn::IsConvertFp16ToFp32SupportedCl(input, output, reasonIfUnsupported);
85}
86
87bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
88 const TensorInfo& output,
89 Optional<std::string&> reasonIfUnsupported) const
90{
91 return armnn::IsConvertFp32ToFp16SupportedCl(input, output, reasonIfUnsupported);
92}
93
94bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
95 const TensorInfo& output,
96 const Convolution2dDescriptor& descriptor,
97 const TensorInfo& weights,
98 const Optional<TensorInfo>& biases,
99 Optional<std::string&> reasonIfUnsupported) const
100{
101 return armnn::IsConvolution2dSupportedCl(input,
102 output,
103 descriptor,
104 weights,
105 biases,
106 reasonIfUnsupported);
107}
108
109bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
110 const TensorInfo& output,
111 const DepthwiseConvolution2dDescriptor& descriptor,
112 const TensorInfo& weights,
113 const Optional<TensorInfo>& biases,
114 Optional<std::string&> reasonIfUnsupported) const
115{
116 return armnn::IsDepthwiseConvolutionSupportedCl(input,
117 output,
118 descriptor,
119 weights,
120 biases,
121 reasonIfUnsupported);
122}
123
124bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
125 const TensorInfo& input1,
126 const TensorInfo& output,
127 Optional<std::string&> reasonIfUnsupported) const
128{
129 return armnn::IsDivisionSupportedCl(input0, input1, output, reasonIfUnsupported);
130}
131
132bool ClLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
133 const FakeQuantizationDescriptor& descriptor,
134 Optional<std::string&> reasonIfUnsupported) const
135{
136 return armnn::IsFakeQuantizationSupportedCl(input, descriptor, reasonIfUnsupported);
137}
138
139bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
140 const TensorInfo& output,
141 Optional<std::string&> reasonIfUnsupported) const
142{
143 return armnn::IsFloorSupportedCl(input, output, reasonIfUnsupported);
144}
145
146bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
147 const TensorInfo& output,
148 const TensorInfo& weights,
149 const TensorInfo& biases,
150 const FullyConnectedDescriptor& descriptor,
151 Optional<std::string&> reasonIfUnsupported) const
152{
153 return armnn::IsFullyConnectedSupportedCl(input,
154 output,
155 weights,
156 biases,
157 descriptor,
158 reasonIfUnsupported);
159}
160
161bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
162 Optional<std::string&> reasonIfUnsupported) const
163{
164 return armnn::IsInputSupportedCl(input, reasonIfUnsupported);
165}
166
167bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
168 const TensorInfo& output,
169 const L2NormalizationDescriptor& descriptor,
170 Optional<std::string&> reasonIfUnsupported) const
171{
172 return armnn::IsL2NormalizationSupportedCl(input, output, descriptor, reasonIfUnsupported);
173}
174
175bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
176 const TensorInfo& outputStateIn,
177 const TensorInfo& cellStateIn,
178 const TensorInfo& scratchBuffer,
179 const TensorInfo& outputStateOut,
180 const TensorInfo& cellStateOut,
181 const TensorInfo& output,
182 const LstmDescriptor& descriptor,
183 const TensorInfo& inputToForgetWeights,
184 const TensorInfo& inputToCellWeights,
185 const TensorInfo& inputToOutputWeights,
186 const TensorInfo& recurrentToForgetWeights,
187 const TensorInfo& recurrentToCellWeights,
188 const TensorInfo& recurrentToOutputWeights,
189 const TensorInfo& forgetGateBias,
190 const TensorInfo& cellBias,
191 const TensorInfo& outputGateBias,
192 const TensorInfo* inputToInputWeights,
193 const TensorInfo* recurrentToInputWeights,
194 const TensorInfo* cellToInputWeights,
195 const TensorInfo* inputGateBias,
196 const TensorInfo* projectionWeights,
197 const TensorInfo* projectionBias,
198 const TensorInfo* cellToForgetWeights,
199 const TensorInfo* cellToOutputWeights,
200 Optional<std::string&> reasonIfUnsupported) const
201{
202 return armnn::IsLstmSupportedCl(input,
203 outputStateIn,
204 cellStateIn,
205 scratchBuffer,
206 outputStateOut,
207 cellStateOut,
208 output,
209 descriptor,
210 inputToForgetWeights,
211 inputToCellWeights,
212 inputToOutputWeights,
213 recurrentToForgetWeights,
214 recurrentToCellWeights,
215 recurrentToOutputWeights,
216 forgetGateBias,
217 cellBias,
218 outputGateBias,
219 inputToInputWeights,
220 recurrentToInputWeights,
221 cellToInputWeights,
222 inputGateBias,
223 projectionWeights,
224 projectionBias,
225 cellToForgetWeights,
226 cellToOutputWeights,
227 reasonIfUnsupported);
228}
229
230bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
231 const TensorInfo& output,
232 const MeanDescriptor& descriptor,
233 Optional<std::string&> reasonIfUnsupported) const
234{
235 return armnn::IsMeanSupportedCl(input, output, descriptor,reasonIfUnsupported);
236}
237
238bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
239 const OriginsDescriptor& descriptor,
240 Optional<std::string&> reasonIfUnsupported) const
241{
242 return armnn::IsMergerSupportedCl(inputs, descriptor, reasonIfUnsupported);
243}
244
245bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
246 const TensorInfo& input1,
247 const TensorInfo& output,
248 Optional<std::string&> reasonIfUnsupported) const
249{
250 return armnn::IsMultiplicationSupportedCl(input0, input1, output, reasonIfUnsupported);
251}
252
253bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
254 const TensorInfo& output,
255 const NormalizationDescriptor& descriptor,
256 Optional<std::string&> reasonIfUnsupported) const
257{
258 return armnn::IsNormalizationSupportedCl(input,
259 output,
260 descriptor,
261 reasonIfUnsupported);
262}
263
264bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
265 Optional<std::string&> reasonIfUnsupported) const
266{
267 return armnn::IsOutputSupportedCl(output, reasonIfUnsupported);
268}
269
270bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
271 const TensorInfo& output,
272 const PadDescriptor& descriptor,
273 Optional<std::string&> reasonIfUnsupported) const
274{
275 return armnn::IsPadSupportedCl(input, output, descriptor, reasonIfUnsupported);
276}
277
278bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
279 const TensorInfo& output,
280 const PermuteDescriptor& descriptor,
281 Optional<std::string&> reasonIfUnsupported) const
282{
283 return armnn::IsPermuteSupportedCl(input, output, descriptor, reasonIfUnsupported);
284}
285
286bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
287 const TensorInfo& output,
288 const Pooling2dDescriptor& descriptor,
289 Optional<std::string&> reasonIfUnsupported) const
290{
291 return armnn::IsPooling2dSupportedCl(input, output, descriptor, reasonIfUnsupported);
292}
293
294bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
295 Optional<std::string&> reasonIfUnsupported) const
296{
297 return armnn::IsReshapeSupportedCl(input, reasonIfUnsupported);
298}
299
300bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
301 Optional<std::string&> reasonIfUnsupported) const
302{
303 return armnn::IsResizeBilinearSupportedCl(input, reasonIfUnsupported);
304}
305
306bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
307 const TensorInfo& output,
308 const SoftmaxDescriptor& descriptor,
309 Optional<std::string&> reasonIfUnsupported) const
310{
311 return armnn::IsSoftmaxSupportedCl(input, output, descriptor, reasonIfUnsupported);
312}
313
314bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
315 const ViewsDescriptor& descriptor,
316 Optional<std::string&> reasonIfUnsupported) const
317{
318 return armnn::IsSplitterSupportedCl(input, descriptor, reasonIfUnsupported);
319}
320
321bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
322 const TensorInfo& input1,
323 const TensorInfo& output,
324 Optional<std::string&> reasonIfUnsupported) const
325{
326 return armnn::IsSubtractionSupportedCl(input0, input1, output, reasonIfUnsupported);
327}
328
//
// Implementation functions
//
// TODO: Functions kept for backward compatibility. Remove redundant functions
//       once the transition to pluggable backends is complete.
334
telsoa014fcda012018-03-09 14:13:49 +0000335namespace
336{
337template<unsigned int FilterSize>
338bool IsMatchingSize2d(const TensorInfo& weightInfo)
339{
telsoa01c577f2c2018-08-31 09:22:23 +0100340 // Width & Height must match.
telsoa014fcda012018-03-09 14:13:49 +0000341 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
342}
343
// Base case: true when actualStride equals the single compile-time stride.
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

// Recursive case: true when actualStride equals any of the listed compile-time
// strides. Peels off FirstStride and recurses on the remainder, which bottoms
// out in the single-stride overload above once one value is left.
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) ||
           IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
355
arovir01085f0a42018-10-08 14:48:19 +0100356bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000357{
358#if ARMCOMPUTECL_ENABLED
359 return true;
360#else
arovir01085f0a42018-10-08 14:48:19 +0100361 if (reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000362 {
arovir01085f0a42018-10-08 14:48:19 +0100363 reasonIfUnsupported.value() = "The armnn library has been built without CL support";
telsoa014fcda012018-03-09 14:13:49 +0000364 }
365 return false;
366#endif
367}
368
// Evaluates `expr` when the CL backend is compiled in; otherwise discards it
// and evaluates IsClBackendSupported(reasonIfUnsupported) instead, so a
// variable named reasonIfUnsupported must be in scope at the expansion site.
// The guard uses defined() to agree with the #ifdef protecting the CL includes.
#if defined(ARMCOMPUTECL_ENABLED)
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif
374
#if defined(ARMCOMPUTECL_ENABLED)
// Calls an ACL validate function with the forwarded arguments and converts
// the returned arm_compute::Status into a bool. On failure, the status's
// error description is copied into reasonIfUnsupported when one was supplied.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = aclStatus.error_description();
    }
    return supported;
}

// Expands to a return statement; intended as the last statement of a
// bool-returning support function.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// CL compiled out: `func` must not be referenced (its header is not included),
// so report the backend as unavailable. The forwarded arguments are explicitly
// marked as used so the enclosing function's parameters do not become unused.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    boost::ignore_unused(__VA_ARGS__); \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
394
395} //namespace
396
telsoa01c577f2c2018-08-31 09:22:23 +0100397template<typename FloatFunc, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100398bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000399 DataType dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100400 FloatFunc floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000401 Uint8Func uint8FuncPtr,
402 Params&&... params)
403{
404 return IsClBackendSupported(reasonIfUnsupported) &&
405 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
406 dataType,
407 floatFuncPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100408 floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000409 uint8FuncPtr,
410 std::forward<Params>(params)...);
411}
412
413bool IsActivationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100414 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000415 const ActivationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100416 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000417{
telsoa01c577f2c2018-08-31 09:22:23 +0100418 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
419 reasonIfUnsupported,
420 input,
421 output,
422 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000423}
424
425bool IsAdditionSupportedCl(const TensorInfo& input0,
426 const TensorInfo& input1,
427 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100428 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000429{
arovir01085f0a42018-10-08 14:48:19 +0100430 FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
431 reasonIfUnsupported,
432 input0,
433 input1,
434 output);
telsoa014fcda012018-03-09 14:13:49 +0000435}
436
437bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100438 const TensorInfo& output,
439 const TensorInfo& mean,
440 const TensorInfo& var,
441 const TensorInfo& beta,
442 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +0000443 const BatchNormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100444 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000445{
telsoa01c577f2c2018-08-31 09:22:23 +0100446 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
447 reasonIfUnsupported,
448 input,
449 output,
450 mean,
451 var,
452 beta,
453 gamma,
454 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000455}
456
457bool IsConstantSupportedCl(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100458 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000459{
460 return IsSupportedForDataTypeCl(reasonIfUnsupported,
461 output.GetDataType(),
462 &TrueFunc<>,
463 &FalseFuncU8<>);
464}
465
466bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
467{
468 bool isSupported = false;
469
470 bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
471 bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);
472
473 bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
474 bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);
475
476 bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
477 bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
478
telsoa01c577f2c2018-08-31 09:22:23 +0100479 // 1x1 convolution with strides of 1,2,3.
telsoa014fcda012018-03-09 14:13:49 +0000480 isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
481
telsoa01c577f2c2018-08-31 09:22:23 +0100482 // 3x3 convolution with strides of 1,2.
telsoa014fcda012018-03-09 14:13:49 +0000483 isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
484
485 // 5x5 convolution with strides of 1,2
486 isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
487
telsoa01c577f2c2018-08-31 09:22:23 +0100488 //Fall back to normal convolution for the asymmetric padding case.
telsoa014fcda012018-03-09 14:13:49 +0000489 if (desc.m_PadLeft != desc.m_PadRight ||
490 desc.m_PadTop != desc.m_PadBottom)
491 {
telsoa01c577f2c2018-08-31 09:22:23 +0100492 //Direct convolution does not support asymmetric padding yet.
telsoa014fcda012018-03-09 14:13:49 +0000493 isSupported = false;
494 }
495
496 return isSupported;
497}
498
arovir01085f0a42018-10-08 14:48:19 +0100499bool IsDirectConvolution2dParamsSupportedCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000500 const Convolution2dDescriptor& parameters,
501 const TensorInfo& weightInfo)
502{
arovir01085f0a42018-10-08 14:48:19 +0100503 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000504 return IsClDirectConvolution2dSupported(weightInfo, parameters);
505}
506
507bool IsConvolution2dSupportedCl(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +0100508 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000509 const Convolution2dDescriptor& descriptor,
510 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100511 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100512 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000513{
surmeh013537c2c2018-05-18 16:31:43 +0100514 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
515 reasonIfUnsupported,
516 input,
517 output,
518 descriptor,
519 weights,
520 biases);
telsoa014fcda012018-03-09 14:13:49 +0000521}
522
523bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100524 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000525 const DepthwiseConvolution2dDescriptor& descriptor,
526 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100527 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100528 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000529{
telsoa01c577f2c2018-08-31 09:22:23 +0100530 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
531 reasonIfUnsupported,
532 input,
533 output,
534 descriptor,
535 weights,
536 biases);
telsoa014fcda012018-03-09 14:13:49 +0000537}
538
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100539bool IsDivisionSupportedCl(const TensorInfo& input0,
540 const TensorInfo& input1,
541 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100542 Optional<std::string&> reasonIfUnsupported)
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100543{
544 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
545 reasonIfUnsupported,
546 input0,
547 input1,
548 output);
549}
550
David Beckc2044fe2018-09-05 15:00:38 +0100551bool IsSubtractionSupportedCl(const TensorInfo& input0,
552 const TensorInfo& input1,
553 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100554 Optional<std::string&> reasonIfUnsupported)
David Beckc2044fe2018-09-05 15:00:38 +0100555{
arovir01085f0a42018-10-08 14:48:19 +0100556
557 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
558 reasonIfUnsupported,
559 input0,
560 input1,
561 output);
David Beckc2044fe2018-09-05 15:00:38 +0100562}
563
telsoa014fcda012018-03-09 14:13:49 +0000564bool IsFullyConnectedSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100565 const TensorInfo& output,
566 const TensorInfo& weights,
567 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000568 const FullyConnectedDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100569 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000570{
telsoa01c577f2c2018-08-31 09:22:23 +0100571 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
572 reasonIfUnsupported,
573 input,
574 output,
575 weights,
576 biases,
577 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000578}
579
580bool IsInputSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100581 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000582{
583 return IsSupportedForDataTypeCl(reasonIfUnsupported,
584 input.GetDataType(),
585 &TrueFunc<>,
586 &TrueFunc<>);
587}
588
589bool IsL2NormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100590 const TensorInfo& output,
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100591 const L2NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100592 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000593{
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100594 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000595}
596
597bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
598 const OriginsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100599 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000600{
601 ignore_unused(descriptor);
602 return IsSupportedForDataTypeCl(reasonIfUnsupported,
603 inputs[0]->GetDataType(),
604 &TrueFunc<>,
605 &FalseFuncU8<>);
606}
607
608bool IsMultiplicationSupportedCl(const TensorInfo& input0,
609 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100610 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100611 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000612{
telsoa01c577f2c2018-08-31 09:22:23 +0100613 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
614 reasonIfUnsupported,
615 input0,
616 input1,
617 output);
telsoa014fcda012018-03-09 14:13:49 +0000618}
619
620bool IsNormalizationSupportedCl(const TensorInfo& input,
621 const TensorInfo& output,
622 const NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100623 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000624{
625 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
626}
627
628bool IsOutputSupportedCl(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100629 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000630{
631 return IsSupportedForDataTypeCl(reasonIfUnsupported,
632 output.GetDataType(),
633 &TrueFunc<>,
634 &TrueFunc<>);
635}
636
637bool IsPermuteSupportedCl(const TensorInfo& input,
638 const TensorInfo& output,
639 const PermuteDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100640 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000641{
642 ignore_unused(input);
643 ignore_unused(output);
644 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
645}
646
647bool IsPooling2dSupportedCl(const TensorInfo& input,
648 const TensorInfo& output,
649 const Pooling2dDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100650 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000651{
652 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
653}
654
655bool IsResizeBilinearSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100656 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000657{
658 return IsSupportedForDataTypeCl(reasonIfUnsupported,
659 input.GetDataType(),
660 &TrueFunc<>,
661 &FalseFuncU8<>);
662}
663
664bool IsSoftmaxSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100665 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000666 const SoftmaxDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100667 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000668{
669 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100670 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
telsoa014fcda012018-03-09 14:13:49 +0000671}
672
673bool IsSplitterSupportedCl(const TensorInfo& input,
674 const ViewsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100675 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000676{
677 ignore_unused(descriptor);
678 return IsSupportedForDataTypeCl(reasonIfUnsupported,
679 input.GetDataType(),
680 &TrueFunc<>,
681 &TrueFunc<>);
682}
683
684bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
685 const FakeQuantizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100686 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000687{
688 ignore_unused(input);
689 ignore_unused(descriptor);
arovir01085f0a42018-10-08 14:48:19 +0100690 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000691 return false;
692}
693
694bool IsReshapeSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100695 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000696{
697 ignore_unused(input);
arovir01085f0a42018-10-08 14:48:19 +0100698 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000699 return true;
700}
701
702bool IsFloorSupportedCl(const TensorInfo& input,
703 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100704 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000705{
706 ignore_unused(output);
telsoa01c577f2c2018-08-31 09:22:23 +0100707 return IsClBackendSupported(reasonIfUnsupported) &&
708 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
709 input.GetDataType(),
710 &FalseFuncF16<>,
711 &TrueFunc<>,
712 &FalseFuncU8<>);
713}
714
arovir01085f0a42018-10-08 14:48:19 +0100715bool IsLstmSupportedCl(const TensorInfo& input,
716 const TensorInfo& outputStateIn,
717 const TensorInfo& cellStateIn,
718 const TensorInfo& scratchBuffer,
719 const TensorInfo& outputStateOut,
720 const TensorInfo& cellStateOut,
721 const TensorInfo& output,
722 const LstmDescriptor& descriptor,
723 const TensorInfo& inputToForgetWeights,
724 const TensorInfo& inputToCellWeights,
725 const TensorInfo& inputToOutputWeights,
726 const TensorInfo& recurrentToForgetWeights,
727 const TensorInfo& recurrentToCellWeights,
728 const TensorInfo& recurrentToOutputWeights,
729 const TensorInfo& forgetGateBias,
730 const TensorInfo& cellBias,
731 const TensorInfo& outputGateBias,
732 const TensorInfo* inputToInputWeights,
733 const TensorInfo* recurrentToInputWeights,
734 const TensorInfo* cellToInputWeights,
735 const TensorInfo* inputGateBias,
736 const TensorInfo* projectionWeights,
737 const TensorInfo* projectionBias,
738 const TensorInfo* cellToForgetWeights,
739 const TensorInfo* cellToOutputWeights,
740 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100741{
arovir01085f0a42018-10-08 14:48:19 +0100742 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
743 reasonIfUnsupported,
744 input,
745 outputStateIn,
746 cellStateIn,
747 scratchBuffer,
748 outputStateOut,
749 cellStateOut,
750 output,
751 descriptor,
752 inputToForgetWeights,
753 inputToCellWeights,
754 inputToOutputWeights,
755 recurrentToForgetWeights,
756 recurrentToCellWeights,
757 recurrentToOutputWeights,
758 forgetGateBias,
759 cellBias,
760 outputGateBias,
761 inputToInputWeights,
762 recurrentToInputWeights,
763 cellToInputWeights,
764 inputGateBias,
765 projectionWeights,
766 projectionBias,
767 cellToForgetWeights,
768 cellToOutputWeights);
telsoa01c577f2c2018-08-31 09:22:23 +0100769}
770
771bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
772 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100773 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100774{
775 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
776 reasonIfUnsupported,
777 input,
arovir01085f0a42018-10-08 14:48:19 +0100778 output);
telsoa01c577f2c2018-08-31 09:22:23 +0100779}
780
781bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
782 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100783 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100784{
785 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
786 reasonIfUnsupported,
787 input,
arovir01085f0a42018-10-08 14:48:19 +0100788 output);
telsoa014fcda012018-03-09 14:13:49 +0000789}
790
narpra0132b90462018-09-13 11:07:48 +0100791bool IsMeanSupportedCl(const TensorInfo& input,
792 const TensorInfo& output,
793 const MeanDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100794 Optional<std::string&> reasonIfUnsupported)
narpra0132b90462018-09-13 11:07:48 +0100795{
arovir01085f0a42018-10-08 14:48:19 +0100796 ignore_unused(input);
797 ignore_unused(output);
798 ignore_unused(descriptor);
799 ignore_unused(reasonIfUnsupported);
narpra0132b90462018-09-13 11:07:48 +0100800 return false;
801}
802
arovir01085f0a42018-10-08 14:48:19 +0100803bool IsPadSupportedCl(const TensorInfo& input,
804 const TensorInfo& output,
805 const PadDescriptor& descriptor,
806 Optional<std::string&> reasonIfUnsupported)
807{
808 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
809 reasonIfUnsupported,
810 input,
811 output,
812 descriptor);
813}
814
telsoa014fcda012018-03-09 14:13:49 +0000815}