blob: 7c66348b98c239f68455aaf5599335a01b18d8e5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "ClLayerSupport.hpp"
arovir017c22c702018-10-09 11:16:46 +01007
David Beck3cc9a622018-10-12 10:38:31 +01008#include <InternalTypes.hpp>
9#include <LayerSupportCommon.hpp>
10
11#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
13#include <boost/core/ignore_unused.hpp>
14
15#ifdef ARMCOMPUTECL_ENABLED
David Beckac42efd2018-09-26 17:41:13 +010016#include "workloads/ClAdditionWorkload.hpp"
Nattapat Chaimanowonge06757e2018-10-11 15:39:18 +010017#include "workloads/ClActivationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010018#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
19#include "workloads/ClConvertFp16ToFp32Workload.hpp"
20#include "workloads/ClConvertFp32ToFp16Workload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010021#include "workloads/ClConvolution2dWorkload.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +010022#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010023#include "workloads/ClDivisionFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010024#include "workloads/ClFullyConnectedWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010025#include "workloads/ClL2NormalizationFloatWorkload.hpp"
26#include "workloads/ClLstmFloatWorkload.hpp"
27#include "workloads/ClMultiplicationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010028#include "workloads/ClNormalizationFloatWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010029#include "workloads/ClPadWorkload.hpp"
30#include "workloads/ClPermuteWorkload.hpp"
Nattapat Chaimanowongac9e0962018-10-10 17:18:35 +010031#include "workloads/ClPooling2dWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010032#include "workloads/ClSoftmaxBaseWorkload.hpp"
33#include "workloads/ClSubtractionWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +000034#endif
35
36using namespace boost;
37
38namespace armnn
39{
arovir017c22c702018-10-09 11:16:46 +010040
41bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
42 const TensorInfo& output,
43 const ActivationDescriptor& descriptor,
44 Optional<std::string&> reasonIfUnsupported) const
45{
46 return armnn::IsActivationSupportedCl(input, output, descriptor, reasonIfUnsupported);
47}
48
49bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
50 const TensorInfo& input1,
51 const TensorInfo& output,
52 Optional<std::string&> reasonIfUnsupported) const
53{
54 return armnn::IsAdditionSupportedCl(input0, input1, output, reasonIfUnsupported);
55}
56
57bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
58 const TensorInfo& output,
59 const TensorInfo& mean,
60 const TensorInfo& var,
61 const TensorInfo& beta,
62 const TensorInfo& gamma,
63 const BatchNormalizationDescriptor& descriptor,
64 Optional<std::string&> reasonIfUnsupported) const
65{
66 return armnn::IsBatchNormalizationSupportedCl(input,
67 output,
68 mean,
69 var,
70 beta,
71 gamma,
72 descriptor,
73 reasonIfUnsupported);
74}
75
76bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
77 Optional<std::string&> reasonIfUnsupported) const
78{
79 return armnn::IsConstantSupportedCl(output, reasonIfUnsupported);
80}
81
82bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
83 const TensorInfo& output,
84 Optional<std::string&> reasonIfUnsupported) const
85{
86 return armnn::IsConvertFp16ToFp32SupportedCl(input, output, reasonIfUnsupported);
87}
88
89bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
90 const TensorInfo& output,
91 Optional<std::string&> reasonIfUnsupported) const
92{
93 return armnn::IsConvertFp32ToFp16SupportedCl(input, output, reasonIfUnsupported);
94}
95
96bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
97 const TensorInfo& output,
98 const Convolution2dDescriptor& descriptor,
99 const TensorInfo& weights,
100 const Optional<TensorInfo>& biases,
101 Optional<std::string&> reasonIfUnsupported) const
102{
103 return armnn::IsConvolution2dSupportedCl(input,
104 output,
105 descriptor,
106 weights,
107 biases,
108 reasonIfUnsupported);
109}
110
111bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
112 const TensorInfo& output,
113 const DepthwiseConvolution2dDescriptor& descriptor,
114 const TensorInfo& weights,
115 const Optional<TensorInfo>& biases,
116 Optional<std::string&> reasonIfUnsupported) const
117{
118 return armnn::IsDepthwiseConvolutionSupportedCl(input,
119 output,
120 descriptor,
121 weights,
122 biases,
123 reasonIfUnsupported);
124}
125
126bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
127 const TensorInfo& input1,
128 const TensorInfo& output,
129 Optional<std::string&> reasonIfUnsupported) const
130{
131 return armnn::IsDivisionSupportedCl(input0, input1, output, reasonIfUnsupported);
132}
133
134bool ClLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
135 const FakeQuantizationDescriptor& descriptor,
136 Optional<std::string&> reasonIfUnsupported) const
137{
138 return armnn::IsFakeQuantizationSupportedCl(input, descriptor, reasonIfUnsupported);
139}
140
141bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
142 const TensorInfo& output,
143 Optional<std::string&> reasonIfUnsupported) const
144{
145 return armnn::IsFloorSupportedCl(input, output, reasonIfUnsupported);
146}
147
148bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
149 const TensorInfo& output,
150 const TensorInfo& weights,
151 const TensorInfo& biases,
152 const FullyConnectedDescriptor& descriptor,
153 Optional<std::string&> reasonIfUnsupported) const
154{
155 return armnn::IsFullyConnectedSupportedCl(input,
156 output,
157 weights,
158 biases,
159 descriptor,
160 reasonIfUnsupported);
161}
162
163bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
164 Optional<std::string&> reasonIfUnsupported) const
165{
166 return armnn::IsInputSupportedCl(input, reasonIfUnsupported);
167}
168
169bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
170 const TensorInfo& output,
171 const L2NormalizationDescriptor& descriptor,
172 Optional<std::string&> reasonIfUnsupported) const
173{
174 return armnn::IsL2NormalizationSupportedCl(input, output, descriptor, reasonIfUnsupported);
175}
176
177bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
178 const TensorInfo& outputStateIn,
179 const TensorInfo& cellStateIn,
180 const TensorInfo& scratchBuffer,
181 const TensorInfo& outputStateOut,
182 const TensorInfo& cellStateOut,
183 const TensorInfo& output,
184 const LstmDescriptor& descriptor,
185 const TensorInfo& inputToForgetWeights,
186 const TensorInfo& inputToCellWeights,
187 const TensorInfo& inputToOutputWeights,
188 const TensorInfo& recurrentToForgetWeights,
189 const TensorInfo& recurrentToCellWeights,
190 const TensorInfo& recurrentToOutputWeights,
191 const TensorInfo& forgetGateBias,
192 const TensorInfo& cellBias,
193 const TensorInfo& outputGateBias,
194 const TensorInfo* inputToInputWeights,
195 const TensorInfo* recurrentToInputWeights,
196 const TensorInfo* cellToInputWeights,
197 const TensorInfo* inputGateBias,
198 const TensorInfo* projectionWeights,
199 const TensorInfo* projectionBias,
200 const TensorInfo* cellToForgetWeights,
201 const TensorInfo* cellToOutputWeights,
202 Optional<std::string&> reasonIfUnsupported) const
203{
204 return armnn::IsLstmSupportedCl(input,
205 outputStateIn,
206 cellStateIn,
207 scratchBuffer,
208 outputStateOut,
209 cellStateOut,
210 output,
211 descriptor,
212 inputToForgetWeights,
213 inputToCellWeights,
214 inputToOutputWeights,
215 recurrentToForgetWeights,
216 recurrentToCellWeights,
217 recurrentToOutputWeights,
218 forgetGateBias,
219 cellBias,
220 outputGateBias,
221 inputToInputWeights,
222 recurrentToInputWeights,
223 cellToInputWeights,
224 inputGateBias,
225 projectionWeights,
226 projectionBias,
227 cellToForgetWeights,
228 cellToOutputWeights,
229 reasonIfUnsupported);
230}
231
232bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
233 const TensorInfo& output,
234 const MeanDescriptor& descriptor,
235 Optional<std::string&> reasonIfUnsupported) const
236{
237 return armnn::IsMeanSupportedCl(input, output, descriptor,reasonIfUnsupported);
238}
239
240bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
241 const OriginsDescriptor& descriptor,
242 Optional<std::string&> reasonIfUnsupported) const
243{
244 return armnn::IsMergerSupportedCl(inputs, descriptor, reasonIfUnsupported);
245}
246
247bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
248 const TensorInfo& input1,
249 const TensorInfo& output,
250 Optional<std::string&> reasonIfUnsupported) const
251{
252 return armnn::IsMultiplicationSupportedCl(input0, input1, output, reasonIfUnsupported);
253}
254
255bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
256 const TensorInfo& output,
257 const NormalizationDescriptor& descriptor,
258 Optional<std::string&> reasonIfUnsupported) const
259{
260 return armnn::IsNormalizationSupportedCl(input,
261 output,
262 descriptor,
263 reasonIfUnsupported);
264}
265
266bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
267 Optional<std::string&> reasonIfUnsupported) const
268{
269 return armnn::IsOutputSupportedCl(output, reasonIfUnsupported);
270}
271
272bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
273 const TensorInfo& output,
274 const PadDescriptor& descriptor,
275 Optional<std::string&> reasonIfUnsupported) const
276{
277 return armnn::IsPadSupportedCl(input, output, descriptor, reasonIfUnsupported);
278}
279
280bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const PermuteDescriptor& descriptor,
283 Optional<std::string&> reasonIfUnsupported) const
284{
285 return armnn::IsPermuteSupportedCl(input, output, descriptor, reasonIfUnsupported);
286}
287
288bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
289 const TensorInfo& output,
290 const Pooling2dDescriptor& descriptor,
291 Optional<std::string&> reasonIfUnsupported) const
292{
293 return armnn::IsPooling2dSupportedCl(input, output, descriptor, reasonIfUnsupported);
294}
295
296bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
297 Optional<std::string&> reasonIfUnsupported) const
298{
299 return armnn::IsReshapeSupportedCl(input, reasonIfUnsupported);
300}
301
302bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
303 Optional<std::string&> reasonIfUnsupported) const
304{
305 return armnn::IsResizeBilinearSupportedCl(input, reasonIfUnsupported);
306}
307
308bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
309 const TensorInfo& output,
310 const SoftmaxDescriptor& descriptor,
311 Optional<std::string&> reasonIfUnsupported) const
312{
313 return armnn::IsSoftmaxSupportedCl(input, output, descriptor, reasonIfUnsupported);
314}
315
316bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
317 const ViewsDescriptor& descriptor,
318 Optional<std::string&> reasonIfUnsupported) const
319{
320 return armnn::IsSplitterSupportedCl(input, descriptor, reasonIfUnsupported);
321}
322
323bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
324 const TensorInfo& input1,
325 const TensorInfo& output,
326 Optional<std::string&> reasonIfUnsupported) const
327{
328 return armnn::IsSubtractionSupportedCl(input0, input1, output, reasonIfUnsupported);
329}
330
331//
332// Implementation functions
333//
334// TODO: Functions kept for backward compatibility. Remove redundant functions
335// once transition to plugable backends is complete.
336
telsoa014fcda012018-03-09 14:13:49 +0000337namespace
338{
339template<unsigned int FilterSize>
340bool IsMatchingSize2d(const TensorInfo& weightInfo)
341{
telsoa01c577f2c2018-08-31 09:22:23 +0100342 // Width & Height must match.
telsoa014fcda012018-03-09 14:13:49 +0000343 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
344}
345
346template<uint32_t ValidStride>
347bool IsMatchingStride(uint32_t actualStride)
348{
349 return ValidStride == actualStride;
350}
351
352template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
353bool IsMatchingStride(uint32_t actualStride)
354{
355 return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
356};
357
arovir01085f0a42018-10-08 14:48:19 +0100358bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000359{
360#if ARMCOMPUTECL_ENABLED
361 return true;
362#else
arovir01085f0a42018-10-08 14:48:19 +0100363 if (reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000364 {
arovir01085f0a42018-10-08 14:48:19 +0100365 reasonIfUnsupported.value() = "The armnn library has been built without CL support";
telsoa014fcda012018-03-09 14:13:49 +0000366 }
367 return false;
368#endif
369}
370
371#if ARMCOMPUTECL_ENABLED
372#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
373#else
374#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
375#endif
376
377#if ARMCOMPUTECL_ENABLED
378template<class FuncType, class... Args>
arovir01085f0a42018-10-08 14:48:19 +0100379inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
telsoa014fcda012018-03-09 14:13:49 +0000380{
381 arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
382 const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
383 if (!supported && reasonIfUnsupported)
384 {
arovir01085f0a42018-10-08 14:48:19 +0100385 reasonIfUnsupported.value() = aclStatus.error_description();
telsoa014fcda012018-03-09 14:13:49 +0000386 }
387 return supported;
388}
389
390#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
391 return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
392#else
393#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
394 return IsClBackendSupported(reasonIfUnsupported);
395#endif
396
397} //namespace
398
telsoa01c577f2c2018-08-31 09:22:23 +0100399template<typename FloatFunc, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100400bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000401 DataType dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100402 FloatFunc floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000403 Uint8Func uint8FuncPtr,
404 Params&&... params)
405{
406 return IsClBackendSupported(reasonIfUnsupported) &&
407 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
408 dataType,
409 floatFuncPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100410 floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000411 uint8FuncPtr,
412 std::forward<Params>(params)...);
413}
414
415bool IsActivationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100416 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000417 const ActivationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100418 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000419{
telsoa01c577f2c2018-08-31 09:22:23 +0100420 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
421 reasonIfUnsupported,
422 input,
423 output,
424 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000425}
426
427bool IsAdditionSupportedCl(const TensorInfo& input0,
428 const TensorInfo& input1,
429 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100430 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000431{
arovir01085f0a42018-10-08 14:48:19 +0100432 FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
433 reasonIfUnsupported,
434 input0,
435 input1,
436 output);
telsoa014fcda012018-03-09 14:13:49 +0000437}
438
439bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100440 const TensorInfo& output,
441 const TensorInfo& mean,
442 const TensorInfo& var,
443 const TensorInfo& beta,
444 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +0000445 const BatchNormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100446 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000447{
telsoa01c577f2c2018-08-31 09:22:23 +0100448 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
449 reasonIfUnsupported,
450 input,
451 output,
452 mean,
453 var,
454 beta,
455 gamma,
456 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000457}
458
459bool IsConstantSupportedCl(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100460 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000461{
462 return IsSupportedForDataTypeCl(reasonIfUnsupported,
463 output.GetDataType(),
464 &TrueFunc<>,
465 &FalseFuncU8<>);
466}
467
468bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
469{
470 bool isSupported = false;
471
472 bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
473 bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);
474
475 bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
476 bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);
477
478 bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
479 bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
480
telsoa01c577f2c2018-08-31 09:22:23 +0100481 // 1x1 convolution with strides of 1,2,3.
telsoa014fcda012018-03-09 14:13:49 +0000482 isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
483
telsoa01c577f2c2018-08-31 09:22:23 +0100484 // 3x3 convolution with strides of 1,2.
telsoa014fcda012018-03-09 14:13:49 +0000485 isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
486
487 // 5x5 convolution with strides of 1,2
488 isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
489
telsoa01c577f2c2018-08-31 09:22:23 +0100490 //Fall back to normal convolution for the asymmetric padding case.
telsoa014fcda012018-03-09 14:13:49 +0000491 if (desc.m_PadLeft != desc.m_PadRight ||
492 desc.m_PadTop != desc.m_PadBottom)
493 {
telsoa01c577f2c2018-08-31 09:22:23 +0100494 //Direct convolution does not support asymmetric padding yet.
telsoa014fcda012018-03-09 14:13:49 +0000495 isSupported = false;
496 }
497
498 return isSupported;
499}
500
arovir01085f0a42018-10-08 14:48:19 +0100501bool IsDirectConvolution2dParamsSupportedCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000502 const Convolution2dDescriptor& parameters,
503 const TensorInfo& weightInfo)
504{
arovir01085f0a42018-10-08 14:48:19 +0100505 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000506 return IsClDirectConvolution2dSupported(weightInfo, parameters);
507}
508
509bool IsConvolution2dSupportedCl(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +0100510 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000511 const Convolution2dDescriptor& descriptor,
512 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100513 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100514 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000515{
surmeh013537c2c2018-05-18 16:31:43 +0100516 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
517 reasonIfUnsupported,
518 input,
519 output,
520 descriptor,
521 weights,
522 biases);
telsoa014fcda012018-03-09 14:13:49 +0000523}
524
525bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100526 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000527 const DepthwiseConvolution2dDescriptor& descriptor,
528 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100529 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100530 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000531{
telsoa01c577f2c2018-08-31 09:22:23 +0100532 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
533 reasonIfUnsupported,
534 input,
535 output,
536 descriptor,
537 weights,
538 biases);
telsoa014fcda012018-03-09 14:13:49 +0000539}
540
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100541bool IsDivisionSupportedCl(const TensorInfo& input0,
542 const TensorInfo& input1,
543 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100544 Optional<std::string&> reasonIfUnsupported)
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100545{
546 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
547 reasonIfUnsupported,
548 input0,
549 input1,
550 output);
551}
552
David Beckc2044fe2018-09-05 15:00:38 +0100553bool IsSubtractionSupportedCl(const TensorInfo& input0,
554 const TensorInfo& input1,
555 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100556 Optional<std::string&> reasonIfUnsupported)
David Beckc2044fe2018-09-05 15:00:38 +0100557{
arovir01085f0a42018-10-08 14:48:19 +0100558
559 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
560 reasonIfUnsupported,
561 input0,
562 input1,
563 output);
David Beckc2044fe2018-09-05 15:00:38 +0100564}
565
telsoa014fcda012018-03-09 14:13:49 +0000566bool IsFullyConnectedSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100567 const TensorInfo& output,
568 const TensorInfo& weights,
569 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000570 const FullyConnectedDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100571 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000572{
telsoa01c577f2c2018-08-31 09:22:23 +0100573 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
574 reasonIfUnsupported,
575 input,
576 output,
577 weights,
578 biases,
579 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000580}
581
582bool IsInputSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100583 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000584{
585 return IsSupportedForDataTypeCl(reasonIfUnsupported,
586 input.GetDataType(),
587 &TrueFunc<>,
588 &TrueFunc<>);
589}
590
591bool IsL2NormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100592 const TensorInfo& output,
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100593 const L2NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100594 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000595{
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100596 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000597}
598
599bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
600 const OriginsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100601 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000602{
603 ignore_unused(descriptor);
604 return IsSupportedForDataTypeCl(reasonIfUnsupported,
605 inputs[0]->GetDataType(),
606 &TrueFunc<>,
607 &FalseFuncU8<>);
608}
609
610bool IsMultiplicationSupportedCl(const TensorInfo& input0,
611 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100612 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100613 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000614{
telsoa01c577f2c2018-08-31 09:22:23 +0100615 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
616 reasonIfUnsupported,
617 input0,
618 input1,
619 output);
telsoa014fcda012018-03-09 14:13:49 +0000620}
621
622bool IsNormalizationSupportedCl(const TensorInfo& input,
623 const TensorInfo& output,
624 const NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100625 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000626{
627 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
628}
629
630bool IsOutputSupportedCl(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100631 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000632{
633 return IsSupportedForDataTypeCl(reasonIfUnsupported,
634 output.GetDataType(),
635 &TrueFunc<>,
636 &TrueFunc<>);
637}
638
639bool IsPermuteSupportedCl(const TensorInfo& input,
640 const TensorInfo& output,
641 const PermuteDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100642 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000643{
644 ignore_unused(input);
645 ignore_unused(output);
646 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
647}
648
649bool IsPooling2dSupportedCl(const TensorInfo& input,
650 const TensorInfo& output,
651 const Pooling2dDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100652 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000653{
654 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
655}
656
657bool IsResizeBilinearSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100658 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000659{
660 return IsSupportedForDataTypeCl(reasonIfUnsupported,
661 input.GetDataType(),
662 &TrueFunc<>,
663 &FalseFuncU8<>);
664}
665
666bool IsSoftmaxSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100667 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000668 const SoftmaxDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100669 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000670{
671 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100672 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
telsoa014fcda012018-03-09 14:13:49 +0000673}
674
675bool IsSplitterSupportedCl(const TensorInfo& input,
676 const ViewsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100677 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000678{
679 ignore_unused(descriptor);
680 return IsSupportedForDataTypeCl(reasonIfUnsupported,
681 input.GetDataType(),
682 &TrueFunc<>,
683 &TrueFunc<>);
684}
685
686bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
687 const FakeQuantizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100688 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000689{
690 ignore_unused(input);
691 ignore_unused(descriptor);
arovir01085f0a42018-10-08 14:48:19 +0100692 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000693 return false;
694}
695
696bool IsReshapeSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100697 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000698{
699 ignore_unused(input);
arovir01085f0a42018-10-08 14:48:19 +0100700 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000701 return true;
702}
703
704bool IsFloorSupportedCl(const TensorInfo& input,
705 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100706 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000707{
708 ignore_unused(output);
telsoa01c577f2c2018-08-31 09:22:23 +0100709 return IsClBackendSupported(reasonIfUnsupported) &&
710 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
711 input.GetDataType(),
712 &FalseFuncF16<>,
713 &TrueFunc<>,
714 &FalseFuncU8<>);
715}
716
arovir01085f0a42018-10-08 14:48:19 +0100717bool IsLstmSupportedCl(const TensorInfo& input,
718 const TensorInfo& outputStateIn,
719 const TensorInfo& cellStateIn,
720 const TensorInfo& scratchBuffer,
721 const TensorInfo& outputStateOut,
722 const TensorInfo& cellStateOut,
723 const TensorInfo& output,
724 const LstmDescriptor& descriptor,
725 const TensorInfo& inputToForgetWeights,
726 const TensorInfo& inputToCellWeights,
727 const TensorInfo& inputToOutputWeights,
728 const TensorInfo& recurrentToForgetWeights,
729 const TensorInfo& recurrentToCellWeights,
730 const TensorInfo& recurrentToOutputWeights,
731 const TensorInfo& forgetGateBias,
732 const TensorInfo& cellBias,
733 const TensorInfo& outputGateBias,
734 const TensorInfo* inputToInputWeights,
735 const TensorInfo* recurrentToInputWeights,
736 const TensorInfo* cellToInputWeights,
737 const TensorInfo* inputGateBias,
738 const TensorInfo* projectionWeights,
739 const TensorInfo* projectionBias,
740 const TensorInfo* cellToForgetWeights,
741 const TensorInfo* cellToOutputWeights,
742 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100743{
arovir01085f0a42018-10-08 14:48:19 +0100744 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
745 reasonIfUnsupported,
746 input,
747 outputStateIn,
748 cellStateIn,
749 scratchBuffer,
750 outputStateOut,
751 cellStateOut,
752 output,
753 descriptor,
754 inputToForgetWeights,
755 inputToCellWeights,
756 inputToOutputWeights,
757 recurrentToForgetWeights,
758 recurrentToCellWeights,
759 recurrentToOutputWeights,
760 forgetGateBias,
761 cellBias,
762 outputGateBias,
763 inputToInputWeights,
764 recurrentToInputWeights,
765 cellToInputWeights,
766 inputGateBias,
767 projectionWeights,
768 projectionBias,
769 cellToForgetWeights,
770 cellToOutputWeights);
telsoa01c577f2c2018-08-31 09:22:23 +0100771}
772
773bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
774 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100775 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100776{
777 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
778 reasonIfUnsupported,
779 input,
arovir01085f0a42018-10-08 14:48:19 +0100780 output);
telsoa01c577f2c2018-08-31 09:22:23 +0100781}
782
783bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
784 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100785 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100786{
787 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
788 reasonIfUnsupported,
789 input,
arovir01085f0a42018-10-08 14:48:19 +0100790 output);
telsoa014fcda012018-03-09 14:13:49 +0000791}
792
narpra0132b90462018-09-13 11:07:48 +0100793bool IsMeanSupportedCl(const TensorInfo& input,
794 const TensorInfo& output,
795 const MeanDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100796 Optional<std::string&> reasonIfUnsupported)
narpra0132b90462018-09-13 11:07:48 +0100797{
arovir01085f0a42018-10-08 14:48:19 +0100798 ignore_unused(input);
799 ignore_unused(output);
800 ignore_unused(descriptor);
801 ignore_unused(reasonIfUnsupported);
narpra0132b90462018-09-13 11:07:48 +0100802 return false;
803}
804
arovir01085f0a42018-10-08 14:48:19 +0100805bool IsPadSupportedCl(const TensorInfo& input,
806 const TensorInfo& output,
807 const PadDescriptor& descriptor,
808 Optional<std::string&> reasonIfUnsupported)
809{
810 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
811 reasonIfUnsupported,
812 input,
813 output,
814 descriptor);
815}
816
telsoa014fcda012018-03-09 14:13:49 +0000817}