blob: 3a250a69816b73c763f48d5526298fe7dd98aafa [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
8#include <LayerSupportCommon.hpp>
9
telsoa014fcda012018-03-09 14:13:49 +000010#include <armnn/Descriptors.hpp>
11#include <armnn/Types.hpp>
12#include <armnn/Tensor.hpp>
13
14#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015#include "InternalTypes.hpp"
16
17using namespace boost;
18
19namespace armnn
20{
21
arovir011c7c81b2018-10-08 11:34:28 +010022bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
23 const TensorInfo& output,
24 const ActivationDescriptor& descriptor,
25 Optional<std::string&> reasonIfUnsupported) const
26{
arovir01085f0a42018-10-08 14:48:19 +010027 return armnn::IsActivationSupportedRef(input, output, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010028}
29
30bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
31 const TensorInfo& input1,
32 const TensorInfo& output,
33 Optional<std::string&> reasonIfUnsupported) const
34{
arovir01085f0a42018-10-08 14:48:19 +010035 return armnn::IsAdditionSupportedRef(input0, input1, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010036}
37
38bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
39 const TensorInfo& output,
40 const TensorInfo& mean,
41 const TensorInfo& var,
42 const TensorInfo& beta,
43 const TensorInfo& gamma,
44 const BatchNormalizationDescriptor& descriptor,
45 Optional<std::string&> reasonIfUnsupported) const
46{
47 return armnn::IsBatchNormalizationSupportedRef(input,
48 output,
49 mean,
50 var,
51 beta,
52 gamma,
53 descriptor,
arovir01085f0a42018-10-08 14:48:19 +010054 reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010055}
56
57bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
58 Optional<std::string&> reasonIfUnsupported) const
59{
arovir01085f0a42018-10-08 14:48:19 +010060 return armnn::IsConstantSupportedRef(output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010061}
62
63bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
64 const TensorInfo& output,
65 Optional<std::string&> reasonIfUnsupported) const
66{
arovir01085f0a42018-10-08 14:48:19 +010067 return armnn::IsConvertFp16ToFp32SupportedRef(input, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010068}
69
70bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
71 const TensorInfo& output,
72 Optional<std::string&> reasonIfUnsupported) const
73{
arovir01085f0a42018-10-08 14:48:19 +010074 return armnn::IsConvertFp32ToFp16SupportedRef(input, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010075}
76
77bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
78 const TensorInfo& output,
79 const Convolution2dDescriptor& descriptor,
80 const TensorInfo& weights,
81 const Optional<TensorInfo>& biases,
82 Optional<std::string&> reasonIfUnsupported) const
83{
84 return armnn::IsConvolution2dSupportedRef(input,
85 output,
86 descriptor,
87 weights,
88 biases,
arovir01085f0a42018-10-08 14:48:19 +010089 reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +010090}
91
92bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
93 const TensorInfo& output,
94 const DepthwiseConvolution2dDescriptor& descriptor,
95 const TensorInfo& weights,
96 const Optional<TensorInfo>& biases,
97 Optional<std::string&> reasonIfUnsupported) const
98{
99 return armnn::IsDepthwiseConvolutionSupportedRef(input,
100 output,
101 descriptor,
102 weights,
103 biases,
arovir01085f0a42018-10-08 14:48:19 +0100104 reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100105}
106
107bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
108 const TensorInfo& input1,
109 const TensorInfo& output,
110 Optional<std::string&> reasonIfUnsupported) const
111{
arovir01085f0a42018-10-08 14:48:19 +0100112 return armnn::IsDivisionSupportedRef(input0, input1, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100113}
114
115bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
116 const FakeQuantizationDescriptor& descriptor,
117 Optional<std::string&> reasonIfUnsupported) const
118{
arovir01085f0a42018-10-08 14:48:19 +0100119 return armnn::IsFakeQuantizationSupportedRef(input, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100120}
121
122bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
123 const TensorInfo& output,
124 Optional<std::string&> reasonIfUnsupported) const
125{
arovir01085f0a42018-10-08 14:48:19 +0100126 return armnn::IsFloorSupportedRef(input, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100127}
128
129bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
130 const TensorInfo& output,
131 const TensorInfo& weights,
132 const TensorInfo& biases,
133 const FullyConnectedDescriptor& descriptor,
134 Optional<std::string&> reasonIfUnsupported) const
135{
136 return armnn::IsFullyConnectedSupportedRef(input,
137 output,
138 weights,
139 biases,
140 descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100141 reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100142}
143
144bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
145 Optional<std::string&> reasonIfUnsupported) const
146{
arovir01085f0a42018-10-08 14:48:19 +0100147 return armnn::IsInputSupportedRef(input, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100148}
149
150bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
151 const TensorInfo& output,
152 const L2NormalizationDescriptor& descriptor,
153 Optional<std::string&> reasonIfUnsupported) const
154{
arovir01085f0a42018-10-08 14:48:19 +0100155 return armnn::IsL2NormalizationSupportedRef(input, output, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100156}
157
158bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
159 const TensorInfo& outputStateIn,
160 const TensorInfo& cellStateIn,
161 const TensorInfo& scratchBuffer,
162 const TensorInfo& outputStateOut,
163 const TensorInfo& cellStateOut,
164 const TensorInfo& output,
165 const LstmDescriptor& descriptor,
166 const TensorInfo& inputToForgetWeights,
167 const TensorInfo& inputToCellWeights,
168 const TensorInfo& inputToOutputWeights,
169 const TensorInfo& recurrentToForgetWeights,
170 const TensorInfo& recurrentToCellWeights,
171 const TensorInfo& recurrentToOutputWeights,
172 const TensorInfo& forgetGateBias,
173 const TensorInfo& cellBias,
174 const TensorInfo& outputGateBias,
175 const TensorInfo* inputToInputWeights,
176 const TensorInfo* recurrentToInputWeights,
177 const TensorInfo* cellToInputWeights,
178 const TensorInfo* inputGateBias,
179 const TensorInfo* projectionWeights,
180 const TensorInfo* projectionBias,
181 const TensorInfo* cellToForgetWeights,
182 const TensorInfo* cellToOutputWeights,
183 Optional<std::string&> reasonIfUnsupported) const
184{
185 return armnn::IsLstmSupportedRef(input,
186 outputStateIn,
187 cellStateIn,
188 scratchBuffer,
189 outputStateOut,
190 cellStateOut,
191 output,
192 descriptor,
193 inputToForgetWeights,
194 inputToCellWeights,
195 inputToOutputWeights,
196 recurrentToForgetWeights,
197 recurrentToCellWeights,
198 recurrentToOutputWeights,
199 forgetGateBias,
200 cellBias,
201 outputGateBias,
202 inputToInputWeights,
203 recurrentToInputWeights,
204 cellToInputWeights,
205 inputGateBias,
206 projectionWeights,
207 projectionBias,
208 cellToForgetWeights,
209 cellToOutputWeights,
arovir01085f0a42018-10-08 14:48:19 +0100210 reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100211}
212
213bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
214 const TensorInfo& output,
215 const MeanDescriptor& descriptor,
216 Optional<std::string&> reasonIfUnsupported) const
217{
arovir01085f0a42018-10-08 14:48:19 +0100218 return armnn::IsMeanSupportedRef(input, output, descriptor,reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100219}
220
221bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
222 const OriginsDescriptor& descriptor,
223 Optional<std::string&> reasonIfUnsupported) const
224{
arovir01085f0a42018-10-08 14:48:19 +0100225 return armnn::IsMergerSupportedRef(inputs, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100226}
227
228bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
229 const TensorInfo& input1,
230 const TensorInfo& output,
231 Optional<std::string&> reasonIfUnsupported) const
232{
arovir01085f0a42018-10-08 14:48:19 +0100233 return armnn::IsMultiplicationSupportedRef(input0, input1, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100234}
235
236bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
237 const TensorInfo& output,
238 const NormalizationDescriptor& descriptor,
239 Optional<std::string&> reasonIfUnsupported) const
240{
241 return armnn::IsNormalizationSupportedRef(input,
242 output,
243 descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100244 reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100245}
246
247bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
248 Optional<std::string&> reasonIfUnsupported) const
249{
arovir01085f0a42018-10-08 14:48:19 +0100250 return armnn::IsOutputSupportedRef(output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100251}
252
253bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
254 const TensorInfo& output,
255 const PadDescriptor& descriptor,
256 Optional<std::string&> reasonIfUnsupported) const
257{
arovir01085f0a42018-10-08 14:48:19 +0100258 return armnn::IsPadSupportedRef(input, output, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100259}
260
261bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
262 const TensorInfo& output,
263 const PermuteDescriptor& descriptor,
264 Optional<std::string&> reasonIfUnsupported) const
265{
arovir01085f0a42018-10-08 14:48:19 +0100266 return armnn::IsPermuteSupportedRef(input, output, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100267}
268
269bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
270 const TensorInfo& output,
271 const Pooling2dDescriptor& descriptor,
272 Optional<std::string&> reasonIfUnsupported) const
273{
arovir01085f0a42018-10-08 14:48:19 +0100274 return armnn::IsPooling2dSupportedRef(input, output, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100275}
276
277bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
278 Optional<std::string&> reasonIfUnsupported) const
279{
arovir01085f0a42018-10-08 14:48:19 +0100280 return armnn::IsReshapeSupportedRef(input, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100281}
282
283bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
284 Optional<std::string&> reasonIfUnsupported) const
285{
arovir01085f0a42018-10-08 14:48:19 +0100286 return armnn::IsResizeBilinearSupportedRef(input, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100287}
288
289bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
290 const TensorInfo& output,
291 const SoftmaxDescriptor& descriptor,
292 Optional<std::string&> reasonIfUnsupported) const
293{
arovir01085f0a42018-10-08 14:48:19 +0100294 return armnn::IsSoftmaxSupportedRef(input, output, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100295}
296
297bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
298 const ViewsDescriptor& descriptor,
299 Optional<std::string&> reasonIfUnsupported) const
300{
arovir01085f0a42018-10-08 14:48:19 +0100301 return armnn::IsSplitterSupportedRef(input, descriptor, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100302}
303
304bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
305 const TensorInfo& input1,
306 const TensorInfo& output,
307 Optional<std::string&> reasonIfUnsupported) const
308{
arovir01085f0a42018-10-08 14:48:19 +0100309 return armnn::IsSubtractionSupportedRef(input0, input1, output, reasonIfUnsupported);
arovir011c7c81b2018-10-08 11:34:28 +0100310}
311
//
// Implementation functions
//
// TODO: These functions are kept for backward compatibility. Remove them once the transition to pluggable backends is complete!
316
telsoa014fcda012018-03-09 14:13:49 +0000317template<typename Float32Func, typename Uint8Func, typename ... Params>
arovir01085f0a42018-10-08 14:48:19 +0100318bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
telsoa014fcda012018-03-09 14:13:49 +0000319 DataType dataType,
320 Float32Func floatFuncPtr,
321 Uint8Func uint8FuncPtr,
322 Params&&... params)
323{
324 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
325 dataType,
telsoa01c577f2c2018-08-31 09:22:23 +0100326 &FalseFunc<Params...>,
telsoa014fcda012018-03-09 14:13:49 +0000327 floatFuncPtr,
328 uint8FuncPtr,
329 std::forward<Params>(params)...);
330}
331
332bool IsActivationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100333 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000334 const ActivationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100335 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000336{
telsoa01c577f2c2018-08-31 09:22:23 +0100337 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000338 ignore_unused(descriptor);
339 return IsSupportedForDataTypeRef(reasonIfUnsupported,
340 input.GetDataType(),
341 &TrueFunc<>,
342 &TrueFunc<>);
343}
344
345bool IsAdditionSupportedRef(const TensorInfo& input0,
346 const TensorInfo& input1,
347 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100348 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000349{
350 ignore_unused(input1);
351 ignore_unused(output);
352 return IsSupportedForDataTypeRef(reasonIfUnsupported,
353 input0.GetDataType(),
354 &TrueFunc<>,
355 &TrueFunc<>);
356}
357
358bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100359 const TensorInfo& output,
360 const TensorInfo& mean,
361 const TensorInfo& var,
362 const TensorInfo& beta,
363 const TensorInfo& gamma,
telsoa014fcda012018-03-09 14:13:49 +0000364 const BatchNormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100365 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000366{
367 ignore_unused(descriptor);
368 return IsSupportedForDataTypeRef(reasonIfUnsupported,
369 input.GetDataType(),
370 &TrueFunc<>,
371 &TrueFunc<>);
372}
373
374bool IsConstantSupportedRef(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100375 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000376{
377 return IsSupportedForDataTypeRef(reasonIfUnsupported,
378 output.GetDataType(),
379 &TrueFunc<>,
380 &TrueFunc<>);
381}
382
383bool IsConvolution2dSupportedRef(const TensorInfo& input,
surmeh013537c2c2018-05-18 16:31:43 +0100384 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000385 const Convolution2dDescriptor& descriptor,
386 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100387 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100388 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000389{
390 ignore_unused(descriptor);
surmeh013537c2c2018-05-18 16:31:43 +0100391 ignore_unused(output);
392 ignore_unused(weights);
393 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000394 return IsSupportedForDataTypeRef(reasonIfUnsupported,
395 input.GetDataType(),
396 &TrueFunc<>,
397 &TrueFunc<>);
398}
399
400bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100401 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000402 const DepthwiseConvolution2dDescriptor& descriptor,
403 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +0100404 const Optional<TensorInfo>& biases,
arovir01085f0a42018-10-08 14:48:19 +0100405 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000406{
telsoa01c577f2c2018-08-31 09:22:23 +0100407 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000408 ignore_unused(descriptor);
409 ignore_unused(weights);
telsoa01c577f2c2018-08-31 09:22:23 +0100410 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000411 return IsSupportedForDataTypeRef(reasonIfUnsupported,
412 input.GetDataType(),
413 &TrueFunc<>,
414 &TrueFunc<>);
415}
416
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100417bool IsDivisionSupportedRef(const TensorInfo& input0,
418 const TensorInfo& input1,
419 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100420 Optional<std::string&> reasonIfUnsupported)
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100421{
422 ignore_unused(input1);
423 ignore_unused(output);
424 return IsSupportedForDataTypeRef(reasonIfUnsupported,
425 input0.GetDataType(),
426 &TrueFunc<>,
427 &TrueFunc<>);
428}
429
David Beckc2044fe2018-09-05 15:00:38 +0100430bool IsSubtractionSupportedRef(const TensorInfo& input0,
431 const TensorInfo& input1,
432 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100433 Optional<std::string&> reasonIfUnsupported)
David Beckc2044fe2018-09-05 15:00:38 +0100434{
David Beckf195f032018-09-06 16:46:34 +0100435 ignore_unused(input1);
436 ignore_unused(output);
437 return IsSupportedForDataTypeRef(reasonIfUnsupported,
438 input0.GetDataType(),
439 &TrueFunc<>,
440 &TrueFunc<>);
David Beckc2044fe2018-09-05 15:00:38 +0100441}
442
telsoa014fcda012018-03-09 14:13:49 +0000443bool IsFullyConnectedSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100444 const TensorInfo& output,
445 const TensorInfo& weights,
446 const TensorInfo& biases,
telsoa014fcda012018-03-09 14:13:49 +0000447 const FullyConnectedDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100448 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000449{
telsoa01c577f2c2018-08-31 09:22:23 +0100450 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000451 ignore_unused(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +0100452 ignore_unused(weights);
453 ignore_unused(biases);
telsoa014fcda012018-03-09 14:13:49 +0000454 return IsSupportedForDataTypeRef(reasonIfUnsupported,
455 input.GetDataType(),
456 &TrueFunc<>,
457 &TrueFunc<>);
458}
459
460bool IsInputSupportedRef(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100461 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000462{
463 return IsSupportedForDataTypeRef(reasonIfUnsupported,
464 input.GetDataType(),
465 &TrueFunc<>,
466 &TrueFunc<>);
467}
468
469bool IsL2NormalizationSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100470 const TensorInfo& output,
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100471 const L2NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100472 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000473{
telsoa01c577f2c2018-08-31 09:22:23 +0100474 ignore_unused(output);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100475 ignore_unused(descriptor);
telsoa014fcda012018-03-09 14:13:49 +0000476 return IsSupportedForDataTypeRef(reasonIfUnsupported,
477 input.GetDataType(),
478 &TrueFunc<>,
479 &FalseFuncU8<>);
480}
481
482bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
483 const OriginsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100484 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000485{
486 ignore_unused(descriptor);
487 return IsSupportedForDataTypeRef(reasonIfUnsupported,
488 inputs[0]->GetDataType(),
489 &TrueFunc<>,
490 &TrueFunc<>);
491}
492
493bool IsMultiplicationSupportedRef(const TensorInfo& input0,
494 const TensorInfo& input1,
telsoa01c577f2c2018-08-31 09:22:23 +0100495 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100496 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000497{
498 ignore_unused(input1);
telsoa01c577f2c2018-08-31 09:22:23 +0100499 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000500 return IsSupportedForDataTypeRef(reasonIfUnsupported,
501 input0.GetDataType(),
502 &TrueFunc<>,
503 &TrueFunc<>);
504}
505
506bool IsNormalizationSupportedRef(const TensorInfo& input,
507 const TensorInfo& output,
508 const NormalizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100509 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000510{
511 ignore_unused(descriptor);
512 return IsSupportedForDataTypeRef(reasonIfUnsupported,
513 input.GetDataType(),
514 &TrueFunc<>,
515 &FalseFuncU8<>);
516}
517
518bool IsOutputSupportedRef(const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100519 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000520{
521 return IsSupportedForDataTypeRef(reasonIfUnsupported,
522 output.GetDataType(),
523 &TrueFunc<>,
524 &TrueFunc<>);
525}
526
527bool IsPermuteSupportedRef(const TensorInfo& input,
528 const TensorInfo& output,
529 const PermuteDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100530 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000531{
532 ignore_unused(descriptor);
533 return IsSupportedForDataTypeRef(reasonIfUnsupported,
534 input.GetDataType(),
535 &TrueFunc<>,
536 &TrueFunc<>);
537}
538
539bool IsPooling2dSupportedRef(const TensorInfo& input,
540 const TensorInfo& output,
541 const Pooling2dDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100542 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000543{
544 ignore_unused(descriptor);
545 return IsSupportedForDataTypeRef(reasonIfUnsupported,
546 input.GetDataType(),
547 &TrueFunc<>,
548 &TrueFunc<>);
549}
550
551bool IsResizeBilinearSupportedRef(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100552 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000553{
554 return IsSupportedForDataTypeRef(reasonIfUnsupported,
555 input.GetDataType(),
556 &TrueFunc<>,
557 &TrueFunc<>);
558}
559
560bool IsSoftmaxSupportedRef(const TensorInfo& input,
telsoa01c577f2c2018-08-31 09:22:23 +0100561 const TensorInfo& output,
telsoa014fcda012018-03-09 14:13:49 +0000562 const SoftmaxDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100563 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000564{
telsoa01c577f2c2018-08-31 09:22:23 +0100565 ignore_unused(output);
telsoa014fcda012018-03-09 14:13:49 +0000566 ignore_unused(descriptor);
567 return IsSupportedForDataTypeRef(reasonIfUnsupported,
568 input.GetDataType(),
569 &TrueFunc<>,
570 &TrueFunc<>);
571}
572
573bool IsSplitterSupportedRef(const TensorInfo& input,
574 const ViewsDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100575 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000576{
577 ignore_unused(descriptor);
578 return IsSupportedForDataTypeRef(reasonIfUnsupported,
579 input.GetDataType(),
580 &TrueFunc<>,
581 &TrueFunc<>);
582}
583
584bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
585 const FakeQuantizationDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100586 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000587{
588 ignore_unused(descriptor);
589 return IsSupportedForDataTypeRef(reasonIfUnsupported,
590 input.GetDataType(),
591 &TrueFunc<>,
592 &FalseFuncU8<>);
593}
594
595bool IsReshapeSupportedRef(const TensorInfo& input,
arovir01085f0a42018-10-08 14:48:19 +0100596 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000597{
598 return IsSupportedForDataTypeRef(reasonIfUnsupported,
599 input.GetDataType(),
600 &TrueFunc<>,
601 &TrueFunc<>);
602}
603
604bool IsFloorSupportedRef(const TensorInfo& input,
605 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100606 Optional<std::string&> reasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +0000607{
608 ignore_unused(output);
609 return IsSupportedForDataTypeRef(reasonIfUnsupported,
610 input.GetDataType(),
611 &TrueFunc<>,
612 &FalseFuncU8<>);
613}
614
arovir01085f0a42018-10-08 14:48:19 +0100615bool IsLstmSupportedRef(const TensorInfo& input,
616 const TensorInfo& outputStateIn,
617 const TensorInfo& cellStateIn,
618 const TensorInfo& scratchBuffer,
619 const TensorInfo& outputStateOut,
620 const TensorInfo& cellStateOut,
621 const TensorInfo& output,
622 const LstmDescriptor& descriptor,
623 const TensorInfo& inputToForgetWeights,
624 const TensorInfo& inputToCellWeights,
625 const TensorInfo& inputToOutputWeights,
626 const TensorInfo& recurrentToForgetWeights,
627 const TensorInfo& recurrentToCellWeights,
628 const TensorInfo& recurrentToOutputWeights,
629 const TensorInfo& forgetGateBias,
630 const TensorInfo& cellBias,
631 const TensorInfo& outputGateBias,
632 const TensorInfo* inputToInputWeights,
633 const TensorInfo* recurrentToInputWeights,
634 const TensorInfo* cellToInputWeights,
635 const TensorInfo* inputGateBias,
636 const TensorInfo* projectionWeights,
637 const TensorInfo* projectionBias,
638 const TensorInfo* cellToForgetWeights,
639 const TensorInfo* cellToOutputWeights,
640 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100641{
642 ignore_unused(input);
643 ignore_unused(outputStateIn);
644 ignore_unused(cellStateIn);
645 ignore_unused(scratchBuffer);
646 ignore_unused(outputStateOut);
647 ignore_unused(cellStateOut);
648 ignore_unused(output);
649 ignore_unused(descriptor);
650 ignore_unused(inputToForgetWeights);
651 ignore_unused(inputToCellWeights);
652 ignore_unused(inputToOutputWeights);
653 ignore_unused(recurrentToForgetWeights);
654 ignore_unused(recurrentToCellWeights);
655 ignore_unused(recurrentToOutputWeights);
656 ignore_unused(forgetGateBias);
657 ignore_unused(cellBias);
658 ignore_unused(outputGateBias);
659 ignore_unused(inputToInputWeights);
660 ignore_unused(recurrentToInputWeights);
661 ignore_unused(cellToInputWeights);
662 ignore_unused(inputGateBias);
663 ignore_unused(projectionWeights);
664 ignore_unused(projectionBias);
665 ignore_unused(cellToForgetWeights);
666 ignore_unused(cellToOutputWeights);
arovir01085f0a42018-10-08 14:48:19 +0100667 ignore_unused(reasonIfUnsupported);
telsoa01c577f2c2018-08-31 09:22:23 +0100668 return false;
669}
670
671bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
672 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100673 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100674{
675 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
676 input.GetDataType(),
677 &TrueFunc<>,
678 &FalseInputFuncF32<>,
679 &FalseFuncU8<>) &&
680 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
681 output.GetDataType(),
682 &FalseOutputFuncF16<>,
683 &TrueFunc<>,
684 &FalseFuncU8<>));
685}
686
687bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
688 const TensorInfo& output,
arovir01085f0a42018-10-08 14:48:19 +0100689 Optional<std::string&> reasonIfUnsupported)
telsoa01c577f2c2018-08-31 09:22:23 +0100690{
691 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
692 input.GetDataType(),
693 &FalseInputFuncF16<>,
694 &TrueFunc<>,
695 &FalseFuncU8<>) &&
696 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
697 output.GetDataType(),
698 &TrueFunc<>,
699 &FalseOutputFuncF32<>,
700 &FalseFuncU8<>));
701}
702
narpra0132b90462018-09-13 11:07:48 +0100703bool IsMeanSupportedRef(const TensorInfo& input,
704 const TensorInfo& output,
705 const MeanDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100706 Optional<std::string&> reasonIfUnsupported)
narpra0132b90462018-09-13 11:07:48 +0100707{
narpra011e4c31d2018-09-28 11:07:51 +0100708 ignore_unused(output);
709 ignore_unused(descriptor);
710 return IsSupportedForDataTypeRef(reasonIfUnsupported,
711 input.GetDataType(),
712 &TrueFunc<>,
713 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100714}
715
Nina Drozd661dfa72018-10-02 11:14:17 +0100716bool IsPadSupportedRef(const TensorInfo& input,
717 const TensorInfo& output,
718 const PadDescriptor& descriptor,
arovir01085f0a42018-10-08 14:48:19 +0100719 Optional<std::string&> reasonIfUnsupported)
Nina Drozd661dfa72018-10-02 11:14:17 +0100720{
721 ignore_unused(output);
722 ignore_unused(descriptor);
723 return false;
724}
725
arovir011c7c81b2018-10-08 11:34:28 +0100726} // namespace armnn