//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/DescriptorsFwd.hpp>
#include <armnn/Optional.hpp>

#include <cctype>
#include <memory>
#include <string>
#include <vector>

13namespace armnn
14{
15
16class TensorInfo;
17
18class ILayerSupport
19{
20protected:
21 ILayerSupport() {}
22 virtual ~ILayerSupport() {}
23
24public:
25 virtual bool IsActivationSupported(const TensorInfo& input,
26 const TensorInfo& output,
27 const ActivationDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +010028 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010029
30 virtual bool IsAdditionSupported(const TensorInfo& input0,
31 const TensorInfo& input1,
32 const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +010033 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010034
35 virtual bool IsBatchNormalizationSupported(const TensorInfo& input,
36 const TensorInfo& output,
37 const TensorInfo& mean,
38 const TensorInfo& var,
39 const TensorInfo& beta,
40 const TensorInfo& gamma,
41 const BatchNormalizationDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +010042 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010043
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000044 virtual bool IsBatchToSpaceNdSupported(const TensorInfo& input,
45 const TensorInfo& output,
46 const BatchToSpaceNdDescriptor& descriptor,
47 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
48
David Becke97c6e02018-10-03 13:09:28 +010049 virtual bool IsConstantSupported(const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +010050 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010051
52 virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
53 const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +010054 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010055
56 virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
57 const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +010058 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010059
60 virtual bool IsConvolution2dSupported(const TensorInfo& input,
61 const TensorInfo& output,
62 const Convolution2dDescriptor& descriptor,
63 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010064 const Optional<TensorInfo>& biases,
arovir01d6c10ed2018-10-05 15:46:51 +010065 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010066
67 virtual bool IsDepthwiseConvolutionSupported(const TensorInfo& input,
68 const TensorInfo& output,
69 const DepthwiseConvolution2dDescriptor& descriptor,
70 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010071 const Optional<TensorInfo>& biases,
arovir01d6c10ed2018-10-05 15:46:51 +010072 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010073
74 virtual bool IsDivisionSupported(const TensorInfo& input0,
75 const TensorInfo& input1,
76 const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +010077 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010078
arovir01537a0b62018-10-08 12:01:04 +010079 virtual bool IsFakeQuantizationSupported(const TensorInfo& input,
80 const FakeQuantizationDescriptor& descriptor,
81 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010082
arovir01537a0b62018-10-08 12:01:04 +010083 virtual bool IsFloorSupported(const TensorInfo& input,
84 const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +010085 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010086
87 virtual bool IsFullyConnectedSupported(const TensorInfo& input,
88 const TensorInfo& output,
89 const TensorInfo& weights,
90 const TensorInfo& biases,
91 const FullyConnectedDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +010092 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +010093
arovir01537a0b62018-10-08 12:01:04 +010094 virtual bool IsInputSupported(const TensorInfo& input,
95 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
96
David Becke97c6e02018-10-03 13:09:28 +010097 virtual bool IsL2NormalizationSupported(const TensorInfo& input,
98 const TensorInfo& output,
99 const L2NormalizationDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100100 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100101
102 virtual bool IsLstmSupported(const TensorInfo& input,
103 const TensorInfo& outputStateIn,
104 const TensorInfo& cellStateIn,
105 const TensorInfo& scratchBuffer,
106 const TensorInfo& outputStateOut,
107 const TensorInfo& cellStateOut,
108 const TensorInfo& output,
109 const LstmDescriptor& descriptor,
110 const TensorInfo& inputToForgetWeights,
111 const TensorInfo& inputToCellWeights,
112 const TensorInfo& inputToOutputWeights,
113 const TensorInfo& recurrentToForgetWeights,
114 const TensorInfo& recurrentToCellWeights,
115 const TensorInfo& recurrentToOutputWeights,
116 const TensorInfo& forgetGateBias,
117 const TensorInfo& cellBias,
118 const TensorInfo& outputGateBias,
119 const TensorInfo* inputToInputWeights,
120 const TensorInfo* recurrentToInputWeights,
121 const TensorInfo* cellToInputWeights,
122 const TensorInfo* inputGateBias,
123 const TensorInfo* projectionWeights,
124 const TensorInfo* projectionBias,
125 const TensorInfo* cellToForgetWeights,
126 const TensorInfo* cellToOutputWeights,
arovir01d6c10ed2018-10-05 15:46:51 +0100127 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100128
arovir01537a0b62018-10-08 12:01:04 +0100129 virtual bool IsMeanSupported(const TensorInfo& input,
130 const TensorInfo& output,
131 const MeanDescriptor& descriptor,
132 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
133
David Becke97c6e02018-10-03 13:09:28 +0100134 virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000135 const TensorInfo& output,
David Becke97c6e02018-10-03 13:09:28 +0100136 const OriginsDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100137 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100138
139 virtual bool IsMultiplicationSupported(const TensorInfo& input0,
140 const TensorInfo& input1,
141 const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +0100142 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100143
144 virtual bool IsNormalizationSupported(const TensorInfo& input,
145 const TensorInfo& output,
146 const NormalizationDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100147 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100148
149 virtual bool IsOutputSupported(const TensorInfo& output,
arovir01d6c10ed2018-10-05 15:46:51 +0100150 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100151
arovir01537a0b62018-10-08 12:01:04 +0100152 virtual bool IsPadSupported(const TensorInfo& input,
153 const TensorInfo& output,
154 const PadDescriptor& descriptor,
155 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
156
David Becke97c6e02018-10-03 13:09:28 +0100157 virtual bool IsPermuteSupported(const TensorInfo& input,
158 const TensorInfo& output,
159 const PermuteDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100160 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100161
162 virtual bool IsPooling2dSupported(const TensorInfo& input,
163 const TensorInfo& output,
164 const Pooling2dDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100165 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100166
arovir01537a0b62018-10-08 12:01:04 +0100167 virtual bool IsReshapeSupported(const TensorInfo& input,
168 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
169
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000170 virtual bool IsSpaceToBatchNdSupported(const TensorInfo& input,
171 const TensorInfo& output,
172 const SpaceToBatchNdDescriptor& descriptor,
173 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
174
David Becke97c6e02018-10-03 13:09:28 +0100175 virtual bool IsResizeBilinearSupported(const TensorInfo& input,
arovir01d6c10ed2018-10-05 15:46:51 +0100176 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100177
178 virtual bool IsSoftmaxSupported(const TensorInfo& input,
179 const TensorInfo& output,
180 const SoftmaxDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100181 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100182
183 virtual bool IsSplitterSupported(const TensorInfo& input,
184 const ViewsDescriptor& descriptor,
arovir01d6c10ed2018-10-05 15:46:51 +0100185 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100186
Conor Kennedy430b5d82018-11-14 15:28:28 +0000187 virtual bool IsStridedSliceSupported(const TensorInfo& input,
188 const TensorInfo& output,
189 const StridedSliceDescriptor& descriptor,
190 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
191
arovir01537a0b62018-10-08 12:01:04 +0100192 virtual bool IsSubtractionSupported(const TensorInfo& input0,
193 const TensorInfo& input1,
194 const TensorInfo& output,
195 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
David Becke97c6e02018-10-03 13:09:28 +0100196}; // class ILayerSupport
197
David Beck3e9e1152018-10-17 14:17:50 +0100198using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
199
David Becke97c6e02018-10-03 13:09:28 +0100200} // namespace armnn