//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "LayerSupportCommon.hpp"
#include "RefLayerSupport.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>

#include <boost/core/ignore_unused.hpp>
#include "InternalTypes.hpp"

using namespace boost;

namespace armnn
{

namespace
{

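// Converts the Optional<std::string&> used by the ILayerSupport interface into the raw
// std::string* expected by the legacy *SupportedRef functions below: a pointer to the caller's
// string when a reason was requested, nullptr otherwise.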
std::string* GetReasonIfUnsupportedPtr(const Optional<std::string&>& reasonIfUnsupported)
{
    return reasonIfUnsupported ? &reasonIfUnsupported.value() : nullptr;
}

} // anonymous namespace

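// The RefLayerSupport overrides below perform no checks of their own; each one forwards its
// arguments to the corresponding standalone *SupportedRef function defined later in this file.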
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsActivationSupportedRef(input, output, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}
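// Illustrative usage only (the variable names are assumptions, not part of this file): a caller
// that wants the failure reason can pass an Optional wrapping a local string, e.g.
//     RefLayerSupport layerSupport;
//     std::string reason;
//     bool supported = layerSupport.IsActivationSupported(inputInfo, outputInfo, activationDesc,
//                                                         Optional<std::string&>(reason));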

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsAdditionSupportedRef(input0,
                                         input1,
                                         output,
                                         GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& var,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsBatchNormalizationSupportedRef(input,
                                                   output,
                                                   mean,
                                                   var,
                                                   beta,
                                                   gamma,
                                                   descriptor,
                                                   GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsConstantSupportedRef(output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsConvertFp16ToFp32SupportedRef(input, output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsConvertFp32ToFp16SupportedRef(input, output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsConvolution2dSupportedRef(input,
                                              output,
                                              descriptor,
                                              weights,
                                              biases,
                                              GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsDepthwiseConvolutionSupportedRef(input,
                                                     output,
                                                     descriptor,
                                                     weights,
                                                     biases,
                                                     GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsDivisionSupportedRef(input0, input1, output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsFakeQuantizationSupportedRef(input, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsFloorSupportedRef(input, output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsFullyConnectedSupportedRef(input,
                                               output,
                                               weights,
                                               biases,
                                               descriptor,
                                               GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsInputSupportedRef(input, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsL2NormalizationSupportedRef(input,
                                                output,
                                                descriptor,
                                                GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsLstmSupportedRef(input,
                                     outputStateIn,
                                     cellStateIn,
                                     scratchBuffer,
                                     outputStateOut,
                                     cellStateOut,
                                     output,
                                     descriptor,
                                     inputToForgetWeights,
                                     inputToCellWeights,
                                     inputToOutputWeights,
                                     recurrentToForgetWeights,
                                     recurrentToCellWeights,
                                     recurrentToOutputWeights,
                                     forgetGateBias,
                                     cellBias,
                                     outputGateBias,
                                     inputToInputWeights,
                                     recurrentToInputWeights,
                                     cellToInputWeights,
                                     inputGateBias,
                                     projectionWeights,
                                     projectionBias,
                                     cellToForgetWeights,
                                     cellToOutputWeights,
                                     GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsMeanSupportedRef(input, output, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsMergerSupportedRef(inputs, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
                                                const TensorInfo& input1,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsMultiplicationSupportedRef(input0, input1, output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const NormalizationDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsNormalizationSupportedRef(input,
                                              output,
                                              descriptor,
                                              GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsOutputSupportedRef(output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsPadSupportedRef(input, output, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsPermuteSupportedRef(input, output, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const Pooling2dDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsPooling2dSupportedRef(input, output, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsReshapeSupportedRef(input, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsResizeBilinearSupportedRef(input, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const SoftmaxDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsSoftmaxSupportedRef(input, output, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsSplitterSupportedRef(input, descriptor, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
                                             const TensorInfo& input1,
                                             const TensorInfo& output,
                                             Optional<std::string&> reasonIfUnsupported) const
{
    return armnn::IsSubtractionSupportedRef(input0, input1, output, GetReasonIfUnsupportedPtr(reasonIfUnsupported));
}

//
// Implementation functions
//
// TODO: Functions kept for backward compatibility. Remove once the transition to pluggable backends is complete!

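// IsSupportedForDataTypeRef dispatches on the tensor's data type: Float16 is always rejected
// via FalseFunc, while the caller supplies the predicates used for the Float32 and 8-bit
// quantised cases. Most layers below pass TrueFunc for both, i.e. they accept Float32 and
// quantised 8-bit tensors but not Float16 ones.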
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         std::forward<Params>(params)...);
}

bool IsActivationSupportedRef(const TensorInfo& input,
                              const TensorInfo& output,
                              const ActivationDescriptor& descriptor,
                              std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsAdditionSupportedRef(const TensorInfo& input0,
                            const TensorInfo& input1,
                            const TensorInfo& output,
                            std::string* reasonIfUnsupported)
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const TensorInfo& mean,
                                      const TensorInfo& var,
                                      const TensorInfo& beta,
                                      const TensorInfo& gamma,
                                      const BatchNormalizationDescriptor& descriptor,
                                      std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsConstantSupportedRef(const TensorInfo& output,
                            std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     output.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsConvolution2dSupportedRef(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const Convolution2dDescriptor& descriptor,
                                 const TensorInfo& weights,
                                 const Optional<TensorInfo>& biases,
                                 std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    ignore_unused(output);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const DepthwiseConvolution2dDescriptor& descriptor,
                                        const TensorInfo& weights,
                                        const Optional<TensorInfo>& biases,
                                        std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsDivisionSupportedRef(const TensorInfo& input0,
                            const TensorInfo& input1,
                            const TensorInfo& output,
                            std::string* reasonIfUnsupported)
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsSubtractionSupportedRef(const TensorInfo& input0,
                               const TensorInfo& input1,
                               const TensorInfo& output,
                               std::string* reasonIfUnsupported)
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsFullyConnectedSupportedRef(const TensorInfo& input,
                                  const TensorInfo& output,
                                  const TensorInfo& weights,
                                  const TensorInfo& biases,
                                  const FullyConnectedDescriptor& descriptor,
                                  std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsInputSupportedRef(const TensorInfo& input,
                         std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

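// The checks below that pass FalseFuncU8 as the uint8 predicate (L2 normalisation,
// normalisation, fake quantisation and floor) report the layer as unsupported for the 8-bit
// quantised data type, so those layers are effectively Float32-only in the reference backend.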
bool IsL2NormalizationSupportedRef(const TensorInfo& input,
                                   const TensorInfo& output,
                                   const L2NormalizationDescriptor& descriptor,
                                   std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
                          const OriginsDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     inputs[0]->GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsMultiplicationSupportedRef(const TensorInfo& input0,
                                  const TensorInfo& input1,
                                  const TensorInfo& output,
                                  std::string* reasonIfUnsupported)
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsNormalizationSupportedRef(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const NormalizationDescriptor& descriptor,
                                 std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool IsOutputSupportedRef(const TensorInfo& output,
                          std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     output.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsPermuteSupportedRef(const TensorInfo& input,
                           const TensorInfo& output,
                           const PermuteDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsPooling2dSupportedRef(const TensorInfo& input,
                             const TensorInfo& output,
                             const Pooling2dDescriptor& descriptor,
                             std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsResizeBilinearSupportedRef(const TensorInfo& input,
                                  std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsSoftmaxSupportedRef(const TensorInfo& input,
                           const TensorInfo& output,
                           const SoftmaxDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsSplitterSupportedRef(const TensorInfo& input,
                            const ViewsDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
                                    const FakeQuantizationDescriptor& descriptor,
                                    std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

bool IsReshapeSupportedRef(const TensorInfo& input,
                           std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

bool IsFloorSupportedRef(const TensorInfo& input,
                         const TensorInfo& output,
                         std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}

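// LSTM support is not implemented by these legacy checks: the function below ignores all of
// its arguments and always reports the layer as unsupported.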
bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
                        const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                        const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                        const TensorInfo& output, const LstmDescriptor& descriptor,
                        const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
                        const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
                        const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
                        const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
                        const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
                        const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
                        const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
                        const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
                        const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(outputStateIn);
    ignore_unused(cellStateIn);
    ignore_unused(scratchBuffer);
    ignore_unused(outputStateOut);
    ignore_unused(cellStateOut);
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);
    return false;
}

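// The FP16 <-> FP32 conversion checks are two-sided: the input tensor must hold the source
// data type and the output tensor must hold the destination data type; every other
// combination (including 8-bit quantised tensors) is rejected.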
bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
                                     const TensorInfo& output,
                                     std::string* reasonIfUnsupported)
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>));
}

bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
                                     const TensorInfo& output,
                                     std::string* reasonIfUnsupported)
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>));
}

bool IsMeanSupportedRef(const TensorInfo& input,
                        const TensorInfo& output,
                        const MeanDescriptor& descriptor,
                        std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}

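// Pad is not handled by these legacy checks; the function below ignores its arguments and
// always returns false, so Pad layers are reported as unsupported through this path.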
bool IsPadSupportedRef(const TensorInfo& input,
                       const TensorInfo& output,
                       const PadDescriptor& descriptor,
                       std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    ignore_unused(descriptor);
    return false;
}

} // namespace armnn