blob: c4d45fe8ef4008b5571e038b533a3f41367b79be [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "ClLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "ClBackendId.hpp"
arovir017c22c702018-10-09 11:16:46 +01008
David Beck3cc9a622018-10-12 10:38:31 +01009#include <armnn/Descriptors.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <InternalTypes.hpp>
11#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
David Beck111b5d92018-11-12 14:59:37 +000013#include <backendsCommon/BackendRegistry.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010014
telsoa014fcda012018-03-09 14:13:49 +000015#include <boost/core/ignore_unused.hpp>
16
17#ifdef ARMCOMPUTECL_ENABLED
David Beckac42efd2018-09-26 17:41:13 +010018#include "workloads/ClAdditionWorkload.hpp"
Nattapat Chaimanowonge06757e2018-10-11 15:39:18 +010019#include "workloads/ClActivationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010020#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
Mike Kelly831faed2018-11-28 11:52:08 +000021#include "workloads/ClBatchToSpaceNdWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010022#include "workloads/ClConvertFp16ToFp32Workload.hpp"
23#include "workloads/ClConvertFp32ToFp16Workload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010024#include "workloads/ClConvolution2dWorkload.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +010025#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010026#include "workloads/ClDivisionFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010027#include "workloads/ClFullyConnectedWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010028#include "workloads/ClL2NormalizationFloatWorkload.hpp"
29#include "workloads/ClLstmFloatWorkload.hpp"
keidav01a959ee52018-12-19 10:04:58 +000030#include "workloads/ClMaximumWorkload.hpp"
Matteo Martincigh28dcab62018-10-19 16:40:03 +010031#include "workloads/ClMeanWorkload.hpp"
Nikhil Raj8599a412018-11-19 14:51:07 +000032#include "workloads/ClMergerWorkload.hpp"
saoste019292aa32019-01-08 13:55:59 +000033#include "workloads/ClMinimumWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010034#include "workloads/ClMultiplicationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010035#include "workloads/ClNormalizationFloatWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010036#include "workloads/ClPadWorkload.hpp"
37#include "workloads/ClPermuteWorkload.hpp"
Nattapat Chaimanowongac9e0962018-10-10 17:18:35 +010038#include "workloads/ClPooling2dWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010039#include "workloads/ClSoftmaxBaseWorkload.hpp"
Sadik Armaganf4464322018-12-20 16:19:12 +000040#include "workloads/ClSpaceToBatchNdWorkload.hpp"
keidav01d74dc912018-12-10 18:16:07 +000041#include "workloads/ClStridedSliceWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010042#include "workloads/ClSubtractionWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +000043#endif
44
45using namespace boost;
46
47namespace armnn
48{
arovir017c22c702018-10-09 11:16:46 +010049
telsoa014fcda012018-03-09 14:13:49 +000050namespace
51{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +010052
telsoa014fcda012018-03-09 14:13:49 +000053template<unsigned int FilterSize>
54bool IsMatchingSize2d(const TensorInfo& weightInfo)
55{
telsoa01c577f2c2018-08-31 09:22:23 +010056 // Width & Height must match.
telsoa014fcda012018-03-09 14:13:49 +000057 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
58}
59
// Base case: a single allowed stride value.
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return actualStride == ValidStride;
}

// Recursive case: true when actualStride equals any of the listed stride values.
// The first candidate is checked, then the remaining ones. The single-value overload above ends the recursion.
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    if (IsMatchingStride<FirstStride>(actualStride))
    {
        return true;
    }
    return IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
telsoa014fcda012018-03-09 14:13:49 +000071
arovir01085f0a42018-10-08 14:48:19 +010072bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000073{
74#if ARMCOMPUTECL_ENABLED
75 return true;
76#else
arovir01085f0a42018-10-08 14:48:19 +010077 if (reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000078 {
arovir01085f0a42018-10-08 14:48:19 +010079 reasonIfUnsupported.value() = "The armnn library has been built without CL support";
telsoa014fcda012018-03-09 14:13:49 +000080 }
81 return false;
82#endif
83}
84
// FORWARD_CL_LAYER_SUPPORT_FUNC(expr) yields expr when the CL backend is compiled in.
// Otherwise it yields IsClBackendSupported(reasonIfUnsupported), so a variable named
// reasonIfUnsupported must be in scope wherever the macro is used.
// NOTE(review): this file mixes `#if ARMCOMPUTECL_ENABLED` with `#ifdef ARMCOMPUTECL_ENABLED` (include block).
#if !ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#endif
90
#if ARMCOMPUTECL_ENABLED
// Invokes an ACL validation function with the given arguments and converts its arm_compute::Status to a bool.
// On failure, the ACL error description is copied into reasonIfUnsupported (when the caller supplied one).
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status validationStatus = func(std::forward<Args>(args)...);
    if (validationStatus.error_code() == arm_compute::ErrorCode::OK)
    {
        return true;
    }
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = validationStatus.error_description();
    }
    return false;
}

// Expands to a `return` statement that forwards to the ACL validation function.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// Without the CL backend, the validation function and its arguments are discarded and the
// result is whatever IsClBackendSupported reports.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111template<typename FloatFunc, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100112bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000113 DataType dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100114 FloatFunc floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000115 Uint8Func uint8FuncPtr,
116 Params&&... params)
117{
118 return IsClBackendSupported(reasonIfUnsupported) &&
119 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
120 dataType,
121 floatFuncPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100122 floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000123 uint8FuncPtr,
124 std::forward<Params>(params)...);
125}
126
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100127} // anonymous namespace
128
129bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
130 const TensorInfo& output,
131 const ActivationDescriptor& descriptor,
132 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000133{
telsoa01c577f2c2018-08-31 09:22:23 +0100134 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
135 reasonIfUnsupported,
136 input,
137 output,
138 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000139}
140
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100141bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
142 const TensorInfo& input1,
143 const TensorInfo& output,
144 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000145{
arovir01085f0a42018-10-08 14:48:19 +0100146 FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
147 reasonIfUnsupported,
148 input0,
149 input1,
150 output);
telsoa014fcda012018-03-09 14:13:49 +0000151}
152
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100153bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
154 const TensorInfo& output,
155 const TensorInfo& mean,
156 const TensorInfo& var,
157 const TensorInfo& beta,
158 const TensorInfo& gamma,
159 const BatchNormalizationDescriptor& descriptor,
160 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000161{
telsoa01c577f2c2018-08-31 09:22:23 +0100162 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
163 reasonIfUnsupported,
164 input,
165 output,
166 mean,
167 var,
168 beta,
169 gamma,
170 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000171}
172
Mike Kelly831faed2018-11-28 11:52:08 +0000173bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
174 const TensorInfo& output,
175 const BatchToSpaceNdDescriptor& descriptor,
176 Optional<std::string&> reasonIfUnsupported) const
177{
178 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchToSpaceNdWorkloadValidate,
179 reasonIfUnsupported,
180 input,
181 output,
182 descriptor);
183}
184
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100185bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
186 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000187{
188 return IsSupportedForDataTypeCl(reasonIfUnsupported,
189 output.GetDataType(),
190 &TrueFunc<>,
191 &FalseFuncU8<>);
192}
193
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100194bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
195 const TensorInfo& output,
196 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000197{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100198 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
199 reasonIfUnsupported,
200 input,
201 output);
telsoa014fcda012018-03-09 14:13:49 +0000202}
203
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100204bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
205 const TensorInfo& output,
206 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000207{
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100208 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
209 reasonIfUnsupported,
210 input,
211 output);
telsoa014fcda012018-03-09 14:13:49 +0000212}
213
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100214bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
215 const TensorInfo& output,
216 const Convolution2dDescriptor& descriptor,
217 const TensorInfo& weights,
218 const Optional<TensorInfo>& biases,
219 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000220{
surmeh013537c2c2018-05-18 16:31:43 +0100221 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
222 reasonIfUnsupported,
223 input,
224 output,
225 descriptor,
226 weights,
227 biases);
telsoa014fcda012018-03-09 14:13:49 +0000228}
229
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100230bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
231 const TensorInfo& output,
232 const DepthwiseConvolution2dDescriptor& descriptor,
233 const TensorInfo& weights,
234 const Optional<TensorInfo>& biases,
235 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000236{
telsoa01c577f2c2018-08-31 09:22:23 +0100237 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
238 reasonIfUnsupported,
239 input,
240 output,
241 descriptor,
242 weights,
243 biases);
telsoa014fcda012018-03-09 14:13:49 +0000244}
245
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100246bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
247 const TensorInfo& input1,
248 const TensorInfo& output,
249 Optional<std::string&> reasonIfUnsupported) const
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100250{
251 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
252 reasonIfUnsupported,
253 input0,
254 input1,
255 output);
256}
257
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100258bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
259 const TensorInfo& output,
260 Optional<std::string&> reasonIfUnsupported) const
telsoa014fcda012018-03-09 14:13:49 +0000261{
262 ignore_unused(output);
telsoa01c577f2c2018-08-31 09:22:23 +0100263 return IsClBackendSupported(reasonIfUnsupported) &&
264 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
265 input.GetDataType(),
266 &FalseFuncF16<>,
267 &TrueFunc<>,
268 &FalseFuncU8<>);
269}
270
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100271bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
272 const TensorInfo& output,
273 const TensorInfo& weights,
274 const TensorInfo& biases,
275 const FullyConnectedDescriptor& descriptor,
276 Optional<std::string&> reasonIfUnsupported) const
277{
278 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
279 reasonIfUnsupported,
280 input,
281 output,
282 weights,
283 biases,
284 descriptor);
285}
286
287bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
288 Optional<std::string&> reasonIfUnsupported) const
289{
290 return IsSupportedForDataTypeCl(reasonIfUnsupported,
291 input.GetDataType(),
292 &TrueFunc<>,
293 &TrueFunc<>);
294}
295
296bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
297 const TensorInfo& output,
298 const L2NormalizationDescriptor& descriptor,
299 Optional<std::string&> reasonIfUnsupported) const
300{
301 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
302 reasonIfUnsupported,
303 input,
304 output,
305 descriptor);
306}
307
308bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
309 const TensorInfo& outputStateIn,
310 const TensorInfo& cellStateIn,
311 const TensorInfo& scratchBuffer,
312 const TensorInfo& outputStateOut,
313 const TensorInfo& cellStateOut,
314 const TensorInfo& output,
315 const LstmDescriptor& descriptor,
316 const TensorInfo& inputToForgetWeights,
317 const TensorInfo& inputToCellWeights,
318 const TensorInfo& inputToOutputWeights,
319 const TensorInfo& recurrentToForgetWeights,
320 const TensorInfo& recurrentToCellWeights,
321 const TensorInfo& recurrentToOutputWeights,
322 const TensorInfo& forgetGateBias,
323 const TensorInfo& cellBias,
324 const TensorInfo& outputGateBias,
325 const TensorInfo* inputToInputWeights,
326 const TensorInfo* recurrentToInputWeights,
327 const TensorInfo* cellToInputWeights,
328 const TensorInfo* inputGateBias,
329 const TensorInfo* projectionWeights,
330 const TensorInfo* projectionBias,
331 const TensorInfo* cellToForgetWeights,
332 const TensorInfo* cellToOutputWeights,
333 Optional<std::string&> reasonIfUnsupported) const
telsoa01c577f2c2018-08-31 09:22:23 +0100334{
arovir01085f0a42018-10-08 14:48:19 +0100335 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
336 reasonIfUnsupported,
337 input,
338 outputStateIn,
339 cellStateIn,
340 scratchBuffer,
341 outputStateOut,
342 cellStateOut,
343 output,
344 descriptor,
345 inputToForgetWeights,
346 inputToCellWeights,
347 inputToOutputWeights,
348 recurrentToForgetWeights,
349 recurrentToCellWeights,
350 recurrentToOutputWeights,
351 forgetGateBias,
352 cellBias,
353 outputGateBias,
354 inputToInputWeights,
355 recurrentToInputWeights,
356 cellToInputWeights,
357 inputGateBias,
358 projectionWeights,
359 projectionBias,
360 cellToForgetWeights,
361 cellToOutputWeights);
telsoa01c577f2c2018-08-31 09:22:23 +0100362}
363
keidav01a959ee52018-12-19 10:04:58 +0000364bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
365 const TensorInfo& input1,
366 const TensorInfo& output,
367 Optional<std::string&> reasonIfUnsupported) const
368{
369 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
370 reasonIfUnsupported,
371 input0,
372 input1,
373 output);
374}
375
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100376bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
377 const TensorInfo& output,
378 const MeanDescriptor& descriptor,
379 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100380{
Matteo Martincigh28dcab62018-10-19 16:40:03 +0100381 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
382 reasonIfUnsupported,
383 input,
384 output,
385 descriptor);
narpra0132b90462018-09-13 11:07:48 +0100386}
387
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100388bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000389 const TensorInfo& output,
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100390 const OriginsDescriptor& descriptor,
391 Optional<std::string&> reasonIfUnsupported) const
392{
Nikhil Raj8599a412018-11-19 14:51:07 +0000393 if(descriptor.GetNumDimensions() - descriptor.GetConcatAxis() == 1)
394 {
395 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMergerWorkloadValidate,
396 reasonIfUnsupported,
397 inputs,
398 output,
399 descriptor);
400 }
401 else
402 {
403 return IsSupportedForDataTypeCl(reasonIfUnsupported,
404 inputs[0]->GetDataType(),
405 &TrueFunc<>,
narpra0163b08822018-11-20 11:29:12 +0000406 &TrueFunc<>);
Nikhil Raj8599a412018-11-19 14:51:07 +0000407 }
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100408}
409
saoste019292aa32019-01-08 13:55:59 +0000410bool ClLayerSupport::IsMinimumSupported(const TensorInfo& input0,
411 const TensorInfo& input1,
412 const TensorInfo& output,
413 Optional<std::string&> reasonIfUnsupported) const
414{
415 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
416 reasonIfUnsupported,
417 input0,
418 input1,
419 output);
420}
421
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100422bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
423 const TensorInfo& input1,
424 const TensorInfo& output,
425 Optional<std::string&> reasonIfUnsupported) const
426{
427 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
428 reasonIfUnsupported,
429 input0,
430 input1,
431 output);
432}
433
434bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
435 const TensorInfo& output,
436 const NormalizationDescriptor& descriptor,
437 Optional<std::string&> reasonIfUnsupported) const
438{
439 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
440}
441
442bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
443 Optional<std::string&> reasonIfUnsupported) const
444{
445 return IsSupportedForDataTypeCl(reasonIfUnsupported,
446 output.GetDataType(),
447 &TrueFunc<>,
448 &TrueFunc<>);
449}
450
451bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
452 const TensorInfo& output,
453 const PadDescriptor& descriptor,
454 Optional<std::string&> reasonIfUnsupported) const
arovir01085f0a42018-10-08 14:48:19 +0100455{
456 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
457 reasonIfUnsupported,
458 input,
459 output,
460 descriptor);
461}
462
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100463bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
464 const TensorInfo& output,
465 const PermuteDescriptor& descriptor,
466 Optional<std::string&> reasonIfUnsupported) const
467{
468 ignore_unused(input);
469 ignore_unused(output);
470 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000471}
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100472
473bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
474 const TensorInfo& output,
475 const Pooling2dDescriptor& descriptor,
476 Optional<std::string&> reasonIfUnsupported) const
477{
478 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
479}
480
481bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
482 Optional<std::string&> reasonIfUnsupported) const
483{
484 ignore_unused(input);
485 ignore_unused(reasonIfUnsupported);
486 return true;
487}
488
489bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
490 Optional<std::string&> reasonIfUnsupported) const
491{
492 return IsSupportedForDataTypeCl(reasonIfUnsupported,
493 input.GetDataType(),
494 &TrueFunc<>,
495 &FalseFuncU8<>);
496}
497
498bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
499 const TensorInfo& output,
500 const SoftmaxDescriptor& descriptor,
501 Optional<std::string&> reasonIfUnsupported) const
502{
503 ignore_unused(descriptor);
504 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
505}
506
Sadik Armaganf4464322018-12-20 16:19:12 +0000507bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
508 const TensorInfo& output,
509 const SpaceToBatchNdDescriptor& descriptor,
510 Optional<std::string&> reasonIfUnsupported) const
511{
512 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToBatchNdWorkloadValidate,
513 reasonIfUnsupported,
514 input,
515 output,
516 descriptor);
517}
518
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100519bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
520 const ViewsDescriptor& descriptor,
521 Optional<std::string&> reasonIfUnsupported) const
522{
523 ignore_unused(descriptor);
524 return IsSupportedForDataTypeCl(reasonIfUnsupported,
525 input.GetDataType(),
526 &TrueFunc<>,
527 &TrueFunc<>);
528}
529
keidav01d74dc912018-12-10 18:16:07 +0000530bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
531 const TensorInfo& output,
532 const StridedSliceDescriptor& descriptor,
533 Optional<std::string&> reasonIfUnsupported) const
534{
535 FORWARD_WORKLOAD_VALIDATE_FUNC(ClStridedSliceWorkloadValidate,
536 reasonIfUnsupported,
537 input,
538 output,
539 descriptor);
540}
541
Aron Virginas-Tarbcf9f162018-10-15 11:47:37 +0100542bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
543 const TensorInfo& input1,
544 const TensorInfo& output,
545 Optional<std::string&> reasonIfUnsupported) const
546{
547 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
548 reasonIfUnsupported,
549 input0,
550 input1,
551 output);
552}
553
554} // namespace armnn