blob: 494b339952dd8591d0ef1c5721517d42867f0b08 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "LayerSupportCommon.hpp"
7
8#include "ClLayerSupport.hpp"
9#include "InternalTypes.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010#include <armnn/Descriptors.hpp>
11#include <armnn/Types.hpp>
12#include <armnn/Tensor.hpp>
13
14#include <boost/core/ignore_unused.hpp>
15
16#ifdef ARMCOMPUTECL_ENABLED
David Beckac42efd2018-09-26 17:41:13 +010017#include "workloads/ClAdditionWorkload.hpp"
18#include "workloads/ClActivationFloatWorkload.hpp"
19#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
20#include "workloads/ClConvertFp16ToFp32Workload.hpp"
21#include "workloads/ClConvertFp32ToFp16Workload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010022#include "workloads/ClConvolution2dWorkload.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +010023#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010024#include "workloads/ClDivisionFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010025#include "workloads/ClFullyConnectedWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010026#include "workloads/ClL2NormalizationFloatWorkload.hpp"
27#include "workloads/ClLstmFloatWorkload.hpp"
28#include "workloads/ClMultiplicationWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010029#include "workloads/ClNormalizationFloatWorkload.hpp"
arovir01085f0a42018-10-08 14:48:19 +010030#include "workloads/ClPadWorkload.hpp"
31#include "workloads/ClPermuteWorkload.hpp"
32#include "workloads/ClPooling2dBaseWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +010033#include "workloads/ClSoftmaxBaseWorkload.hpp"
34#include "workloads/ClSubtractionWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +000035#endif
36
37using namespace boost;
38
39namespace armnn
40{
41namespace
42{
43template<unsigned int FilterSize>
44bool IsMatchingSize2d(const TensorInfo& weightInfo)
45{
telsoa01c577f2c2018-08-31 09:22:23 +010046 // Width & Height must match.
telsoa014fcda012018-03-09 14:13:49 +000047 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
48}
49
50template<uint32_t ValidStride>
51bool IsMatchingStride(uint32_t actualStride)
52{
53 return ValidStride == actualStride;
54}
55
56template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
57bool IsMatchingStride(uint32_t actualStride)
58{
59 return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
60};
61
arovir01085f0a42018-10-08 14:48:19 +010062bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000063{
64#if ARMCOMPUTECL_ENABLED
65 return true;
66#else
arovir01085f0a42018-10-08 14:48:19 +010067 if (reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000068 {
arovir01085f0a42018-10-08 14:48:19 +010069 reasonIfUnsupported.value() = "The armnn library has been built without CL support";
telsoa014fcda012018-03-09 14:13:49 +000070 }
71 return false;
72#endif
73}
74
75#if ARMCOMPUTECL_ENABLED
76#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
77#else
78#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
79#endif
80
81#if ARMCOMPUTECL_ENABLED
82template<class FuncType, class... Args>
arovir01085f0a42018-10-08 14:48:19 +010083inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
telsoa014fcda012018-03-09 14:13:49 +000084{
85 arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
86 const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
87 if (!supported && reasonIfUnsupported)
88 {
arovir01085f0a42018-10-08 14:48:19 +010089 reasonIfUnsupported.value() = aclStatus.error_description();
telsoa014fcda012018-03-09 14:13:49 +000090 }
91 return supported;
92}
93
94#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
95 return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
96#else
97#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
98 return IsClBackendSupported(reasonIfUnsupported);
99#endif
100
101} //namespace
102
telsoa01c577f2c2018-08-31 09:22:23 +0100103template<typename FloatFunc, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100104bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000105 DataType dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100106 FloatFunc floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000107 Uint8Func uint8FuncPtr,
108 Params&&... params)
109{
110 return IsClBackendSupported(reasonIfUnsupported) &&
111 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
112 dataType,
113 floatFuncPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100114 floatFuncPtr,
telsoa014fcda012018-03-09 14:13:49 +0000115 uint8FuncPtr,
116 std::forward<Params>(params)...);
117}
118
119bool IsActivationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100120 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000121 const ActivationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100122 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000123{
telsoa01c577f2c2018-08-31 09:22:23 +0100124 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
125 reasonIfUnsupported,
126 input,
127 output,
128 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000129}
130
131bool IsAdditionSupportedCl(const TensorInfo& input0,
132 const TensorInfo& input1,
133 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100134 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000135{
arovir01085f0a42018-10-08 14:48:19 +0100136 FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
137 reasonIfUnsupported,
138 input0,
139 input1,
140 output);
telsoa014fcda012018-03-09 14:13:49 +0000141}
142
143bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100144 const TensorInfo& output,
145 const TensorInfo& mean,
146 const TensorInfo& var,
147 const TensorInfo& beta,
148 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +0000149 const BatchNormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100150 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000151{
telsoa01c577f2c2018-08-31 09:22:23 +0100152 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
153 reasonIfUnsupported,
154 input,
155 output,
156 mean,
157 var,
158 beta,
159 gamma,
160 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000161}
162
163bool IsConstantSupportedCl(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100164 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000165{
166 return IsSupportedForDataTypeCl(reasonIfUnsupported,
167 output.GetDataType(),
168 &TrueFunc<>,
169 &FalseFuncU8<>);
170}
171
172bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
173{
174 bool isSupported = false;
175
176 bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
177 bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);
178
179 bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
180 bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);
181
182 bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
183 bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
184
telsoa01c577f2c2018-08-31 09:22:23 +0100185 // 1x1 convolution with strides of 1,2,3.
telsoa014fcda012018-03-09 14:13:49 +0000186 isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
187
telsoa01c577f2c2018-08-31 09:22:23 +0100188 // 3x3 convolution with strides of 1,2.
telsoa014fcda012018-03-09 14:13:49 +0000189 isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
190
191 // 5x5 convolution with strides of 1,2
192 isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
193
telsoa01c577f2c2018-08-31 09:22:23 +0100194 //Fall back to normal convolution for the asymmetric padding case.
telsoa014fcda012018-03-09 14:13:49 +0000195 if (desc.m_PadLeft != desc.m_PadRight ||
196 desc.m_PadTop != desc.m_PadBottom)
197 {
telsoa01c577f2c2018-08-31 09:22:23 +0100198 //Direct convolution does not support asymmetric padding yet.
telsoa014fcda012018-03-09 14:13:49 +0000199 isSupported = false;
200 }
201
202 return isSupported;
203}
204
arovir01085f0a42018-10-08 14:48:19 +0100205bool IsDirectConvolution2dParamsSupportedCl(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000206 const Convolution2dDescriptor& parameters,
207 const TensorInfo& weightInfo)
208{
arovir01085f0a42018-10-08 14:48:19 +0100209 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000210 return IsClDirectConvolution2dSupported(weightInfo, parameters);
211}
212
213bool IsConvolution2dSupportedCl(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +0100214 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000215 const Convolution2dDescriptor& descriptor,
216 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100217 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100218 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000219{
surmeh013537c2c2018-05-18 16:31:43 +0100220 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
221 reasonIfUnsupported,
222 input,
223 output,
224 descriptor,
225 weights,
226 biases);
telsoa014fcda012018-03-09 14:13:49 +0000227}
228
229bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100230 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000231 const DepthwiseConvolution2dDescriptor& descriptor,
232 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100233 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100234 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000235{
telsoa01c577f2c2018-08-31 09:22:23 +0100236 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
237 reasonIfUnsupported,
238 input,
239 output,
240 descriptor,
241 weights,
242 biases);
telsoa014fcda012018-03-09 14:13:49 +0000243}
244
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100245bool IsDivisionSupportedCl(const TensorInfo& input0,
246 const TensorInfo& input1,
247 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100248 Optional<std::string&> reasonIfUnsupported)
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100249{
250 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
251 reasonIfUnsupported,
252 input0,
253 input1,
254 output);
255}
256
David Beckc2044fe2018-09-05 15:00:38 +0100257bool IsSubtractionSupportedCl(const TensorInfo& input0,
258 const TensorInfo& input1,
259 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100260 Optional<std::string&> reasonIfUnsupported)
David Beckc2044fe2018-09-05 15:00:38 +0100261{
arovir01085f0a42018-10-08 14:48:19 +0100262
263 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
264 reasonIfUnsupported,
265 input0,
266 input1,
267 output);
David Beckc2044fe2018-09-05 15:00:38 +0100268}
269
telsoa014fcda012018-03-09 14:13:49 +0000270bool IsFullyConnectedSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100271 const TensorInfo& output,
272 const TensorInfo& weights,
273 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000274 const FullyConnectedDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100275 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000276{
telsoa01c577f2c2018-08-31 09:22:23 +0100277 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
278 reasonIfUnsupported,
279 input,
280 output,
281 weights,
282 biases,
283 descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000284}
285
286bool IsInputSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100287 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000288{
289 return IsSupportedForDataTypeCl(reasonIfUnsupported,
290 input.GetDataType(),
291 &TrueFunc<>,
292 &TrueFunc<>);
293}
294
295bool IsL2NormalizationSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100296 const TensorInfo& output,
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100297 const L2NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100298 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000299{
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100300 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000301}
302
303bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
304 const OriginsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100305 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000306{
307 ignore_unused(descriptor);
308 return IsSupportedForDataTypeCl(reasonIfUnsupported,
309 inputs[0]->GetDataType(),
310 &TrueFunc<>,
311 &FalseFuncU8<>);
312}
313
314bool IsMultiplicationSupportedCl(const TensorInfo& input0,
315 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100316 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100317 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000318{
telsoa01c577f2c2018-08-31 09:22:23 +0100319 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
320 reasonIfUnsupported,
321 input0,
322 input1,
323 output);
telsoa014fcda012018-03-09 14:13:49 +0000324}
325
326bool IsNormalizationSupportedCl(const TensorInfo& input,
327 const TensorInfo& output,
328 const NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100329 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000330{
331 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
332}
333
334bool IsOutputSupportedCl(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100335 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000336{
337 return IsSupportedForDataTypeCl(reasonIfUnsupported,
338 output.GetDataType(),
339 &TrueFunc<>,
340 &TrueFunc<>);
341}
342
343bool IsPermuteSupportedCl(const TensorInfo& input,
344 const TensorInfo& output,
345 const PermuteDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100346 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000347{
348 ignore_unused(input);
349 ignore_unused(output);
350 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
351}
352
353bool IsPooling2dSupportedCl(const TensorInfo& input,
354 const TensorInfo& output,
355 const Pooling2dDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100356 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000357{
358 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
359}
360
361bool IsResizeBilinearSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100362 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000363{
364 return IsSupportedForDataTypeCl(reasonIfUnsupported,
365 input.GetDataType(),
366 &TrueFunc<>,
367 &FalseFuncU8<>);
368}
369
370bool IsSoftmaxSupportedCl(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100371 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000372 const SoftmaxDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100373 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000374{
375 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100376 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
telsoa014fcda012018-03-09 14:13:49 +0000377}
378
379bool IsSplitterSupportedCl(const TensorInfo& input,
380 const ViewsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100381 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000382{
383 ignore_unused(descriptor);
384 return IsSupportedForDataTypeCl(reasonIfUnsupported,
385 input.GetDataType(),
386 &TrueFunc<>,
387 &TrueFunc<>);
388}
389
390bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
391 const FakeQuantizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100392 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000393{
394 ignore_unused(input);
395 ignore_unused(descriptor);
arovir01085f0a42018-10-08 14:48:19 +0100396 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000397 return false;
398}
399
400bool IsReshapeSupportedCl(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100401 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000402{
403 ignore_unused(input);
arovir01085f0a42018-10-08 14:48:19 +0100404 ignore_unused(reasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +0000405 return true;
406}
407
408bool IsFloorSupportedCl(const TensorInfo& input,
409 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100410 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000411{
412 ignore_unused(output);
telsoa01c577f2c2018-08-31 09:22:23 +0100413 return IsClBackendSupported(reasonIfUnsupported) &&
414 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
415 input.GetDataType(),
416 &FalseFuncF16<>,
417 &TrueFunc<>,
418 &FalseFuncU8<>);
419}
420
arovir01085f0a42018-10-08 14:48:19 +0100421bool IsLstmSupportedCl(const TensorInfo& input,
422 const TensorInfo& outputStateIn,
423 const TensorInfo& cellStateIn,
424 const TensorInfo& scratchBuffer,
425 const TensorInfo& outputStateOut,
426 const TensorInfo& cellStateOut,
427 const TensorInfo& output,
428 const LstmDescriptor& descriptor,
429 const TensorInfo& inputToForgetWeights,
430 const TensorInfo& inputToCellWeights,
431 const TensorInfo& inputToOutputWeights,
432 const TensorInfo& recurrentToForgetWeights,
433 const TensorInfo& recurrentToCellWeights,
434 const TensorInfo& recurrentToOutputWeights,
435 const TensorInfo& forgetGateBias,
436 const TensorInfo& cellBias,
437 const TensorInfo& outputGateBias,
438 const TensorInfo* inputToInputWeights,
439 const TensorInfo* recurrentToInputWeights,
440 const TensorInfo* cellToInputWeights,
441 const TensorInfo* inputGateBias,
442 const TensorInfo* projectionWeights,
443 const TensorInfo* projectionBias,
444 const TensorInfo* cellToForgetWeights,
445 const TensorInfo* cellToOutputWeights,
446 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100447{
arovir01085f0a42018-10-08 14:48:19 +0100448 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
449 reasonIfUnsupported,
450 input,
451 outputStateIn,
452 cellStateIn,
453 scratchBuffer,
454 outputStateOut,
455 cellStateOut,
456 output,
457 descriptor,
458 inputToForgetWeights,
459 inputToCellWeights,
460 inputToOutputWeights,
461 recurrentToForgetWeights,
462 recurrentToCellWeights,
463 recurrentToOutputWeights,
464 forgetGateBias,
465 cellBias,
466 outputGateBias,
467 inputToInputWeights,
468 recurrentToInputWeights,
469 cellToInputWeights,
470 inputGateBias,
471 projectionWeights,
472 projectionBias,
473 cellToForgetWeights,
474 cellToOutputWeights);
telsoa01c577f2c2018-08-31 09:22:23 +0100475}
476
477bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
478 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100479 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100480{
481 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
482 reasonIfUnsupported,
483 input,
arovir01085f0a42018-10-08 14:48:19 +0100484 output);
telsoa01c577f2c2018-08-31 09:22:23 +0100485}
486
487bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
488 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100489 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100490{
491 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
492 reasonIfUnsupported,
493 input,
arovir01085f0a42018-10-08 14:48:19 +0100494 output);
telsoa014fcda012018-03-09 14:13:49 +0000495}
496
narpra0132b90462018-09-13 11:07:48 +0100497bool IsMeanSupportedCl(const TensorInfo& input,
498 const TensorInfo& output,
499 const MeanDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100500 Optional<std::string&> reasonIfUnsupported)
narpra0132b90462018-09-13 11:07:48 +0100501{
arovir01085f0a42018-10-08 14:48:19 +0100502 ignore_unused(input);
503 ignore_unused(output);
504 ignore_unused(descriptor);
505 ignore_unused(reasonIfUnsupported);
narpra0132b90462018-09-13 11:07:48 +0100506 return false;
507}
508
arovir01085f0a42018-10-08 14:48:19 +0100509bool IsPadSupportedCl(const TensorInfo& input,
510 const TensorInfo& output,
511 const PadDescriptor& descriptor,
512 Optional<std::string&> reasonIfUnsupported)
513{
514 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
515 reasonIfUnsupported,
516 input,
517 output,
518 descriptor);
519}
520
telsoa014fcda012018-03-09 14:13:49 +0000521}