blob: 05b1ca10cf733dfe9319e26f01e088879668402f [file] [log] [blame]
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27template <int Rank, DType Dtype>
28class OpArgMax : public GraphNode
29{
30public:
Kevin Chengacb550f2021-06-29 15:32:19 -070031 OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
38 using OutEigenType = typename GetEigenType<DType_INT32>::type;
39 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
48template <DType Dtype>
49class OpAvgPool2d : public GraphNode
50{
51public:
Kevin Chengacb550f2021-06-29 15:32:19 -070052 OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
58 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
59 using InEigenType = typename GetEigenType<Dtype>::type;
60 using AccEigenType = typename GetEigenType<AccDtype>::type;
61 using OutEigenType = typename GetEigenType<Dtype>::type;
62 using TIn = Eigen::Tensor<InEigenType, 4>;
63 using TOut = Eigen::Tensor<OutEigenType, 4>;
64
65 static constexpr int64_t QMin = GetQMin<Dtype>::value;
66 static constexpr int64_t QMax = GetQMax<Dtype>::value;
67
68protected:
69 TosaReference::TensorTemplate<TIn>* in;
70 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -070071 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070072 tosa::TosaUnaryQuantInfo* qinfo;
73
74protected:
75 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
Eric Kunze830add42022-01-25 22:56:46 -080076 ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
Eric Kunzee5e26762020-10-13 16:11:07 -070077};
78
79template <DType InDtype, DType WeightDtype>
80class OpConv2d : public GraphNode
81{
82public:
Kevin Chengacb550f2021-06-29 15:32:19 -070083 OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070084 virtual ~OpConv2d();
85
86 virtual int checkTensorAttributes() final;
87 virtual int eval() final;
88
89 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
90
91 using InEigenType = typename GetEigenType<InDtype>::type;
92 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
93 using AccEigenType = typename GetEigenType<AccDtype>::type;
94 using TIn = Eigen::Tensor<InEigenType, 4>;
95 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
96 using TBias = Eigen::Tensor<AccEigenType, 1>;
97 using TAcc = Eigen::Tensor<AccEigenType, 4>;
98
99 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
100 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
101
102protected:
103 TosaReference::TensorTemplate<TIn>* input;
104 TosaReference::TensorTemplate<TWeight>* weight;
105 TosaReference::TensorTemplate<TBias>* bias;
106 TosaReference::TensorTemplate<TAcc>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700107 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 tosa::TosaConvQuantInfo* qinfo;
109};
110
111template <DType InDtype, DType WeightDtype>
Kevin Cheng1533b852021-09-01 12:51:58 -0700112class OpConv3d : public GraphNode
113{
114public:
115 OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
116 virtual ~OpConv3d();
117
118 virtual int checkTensorAttributes() final;
119 virtual int eval() final;
120
121 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
122
123 using InEigenType = typename GetEigenType<InDtype>::type;
124 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
125 using AccEigenType = typename GetEigenType<AccDtype>::type;
126 using TIn = Eigen::Tensor<InEigenType, 5>;
127 using TWeight = Eigen::Tensor<WeightEigenType, 5>;
128 using TBias = Eigen::Tensor<AccEigenType, 1>;
129 using TAcc = Eigen::Tensor<AccEigenType, 5>;
130
131 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
132 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
133
134protected:
135 TosaReference::TensorTemplate<TIn>* input;
136 TosaReference::TensorTemplate<TWeight>* weight;
137 TosaReference::TensorTemplate<TBias>* bias;
138 TosaReference::TensorTemplate<TAcc>* output;
139 tosa::TosaConvAttribute* attribute;
140 tosa::TosaConvQuantInfo* qinfo;
141};
142
143template <DType InDtype, DType WeightDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700144class OpDepthwiseConv2d : public GraphNode
145{
146public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700147 OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700148 virtual ~OpDepthwiseConv2d();
149
150 virtual int checkTensorAttributes() final;
151 virtual int eval() final;
152
153 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
154
155 using InEigenType = typename GetEigenType<InDtype>::type;
156 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
157 using AccEigenType = typename GetEigenType<AccDtype>::type;
158 using TIn = Eigen::Tensor<InEigenType, 4>;
159 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
160 using TBias = Eigen::Tensor<AccEigenType, 1>;
161 using TAcc = Eigen::Tensor<AccEigenType, 4>;
162
163 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
164 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
165
166protected:
167 TosaReference::TensorTemplate<TIn>* input;
168 TosaReference::TensorTemplate<TWeight>* weight;
169 TosaReference::TensorTemplate<TBias>* bias;
170 TosaReference::TensorTemplate<TAcc>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700171 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700172 tosa::TosaConvQuantInfo* qinfo;
173};
174
175template <DType InDtype, DType WeightDtype>
176class OpFullyConnected : public GraphNode
177{
178public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700179 OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700180 virtual ~OpFullyConnected();
181
182 virtual int checkTensorAttributes() final;
183 virtual int eval() final;
184
185 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
186 using InEigenType = typename GetEigenType<InDtype>::type;
187 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
188 using AccEigenType = typename GetEigenType<AccDtype>::type;
189 using TIn = Eigen::Tensor<InEigenType, 2>;
190 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
191 using TBias = Eigen::Tensor<AccEigenType, 1>;
192 using TAcc = Eigen::Tensor<AccEigenType, 2>;
193
194 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
195 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
196
197protected:
198 TosaReference::TensorTemplate<TIn>* input;
199 TosaReference::TensorTemplate<TWeight>* weight;
200 TosaReference::TensorTemplate<TBias>* bias;
201 TosaReference::TensorTemplate<TAcc>* output;
202 tosa::TosaConvQuantInfo* qinfo;
203};
204
205template <DType Dtype>
206class OpMatMul : public GraphNode
207{
208public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700209 OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700210 virtual ~OpMatMul();
211
212 virtual int checkTensorAttributes() final;
213 virtual int eval() final;
214
215 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
216 using InEigenType = typename GetEigenType<Dtype>::type;
217 using AccEigenType = typename GetEigenType<AccDtype>::type;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700218 using TIn = Eigen::Tensor<InEigenType, 3>;
219 using TAcc = Eigen::Tensor<AccEigenType, 3>;
220 using TInRank2 = Eigen::Tensor<InEigenType, 2>;
221 using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
223 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
224
225protected:
226 TosaReference::TensorTemplate<TIn>* a;
227 TosaReference::TensorTemplate<TIn>* b;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700228 TosaReference::TensorTemplate<TAcc>* output;
229 int64_t N;
230 int64_t H;
231 int64_t W;
232 int64_t C;
Eric Kunzee5e26762020-10-13 16:11:07 -0700233 tosa::TosaMatMulQuantInfo* qinfo;
234};
235
236template <DType Dtype>
237class OpMaxPool2d : public GraphNode
238{
239public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700240 OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 virtual ~OpMaxPool2d();
242
243 virtual int checkTensorAttributes();
244 virtual int eval();
245
246 using InEigenType = typename GetEigenType<Dtype>::type;
247 using OutEigenType = typename GetEigenType<Dtype>::type;
248 using TIn = Eigen::Tensor<InEigenType, 4>;
249 using TOut = Eigen::Tensor<OutEigenType, 4>;
250
251protected:
252 TosaReference::TensorTemplate<TIn>* in;
253 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -0700254 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700255};
256
257template <DType InDtype, DType WeightDtype>
258class OpTransposeConv2d : public GraphNode
259{
260public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700261 OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700262 virtual ~OpTransposeConv2d();
263
264 virtual int checkTensorAttributes() final;
265 virtual int eval() final;
266
267 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
268
269 using InEigenType = typename GetEigenType<InDtype>::type;
270 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
271 using AccEigenType = typename GetEigenType<AccDtype>::type;
272 using TIn = Eigen::Tensor<InEigenType, 4>;
273 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
274 using TBias = Eigen::Tensor<AccEigenType, 1>;
275 using TAcc = Eigen::Tensor<AccEigenType, 4>;
276
277 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
278 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
279
280protected:
281 TosaReference::TensorTemplate<TIn>* input;
282 TosaReference::TensorTemplate<TWeight>* weight;
283 TosaReference::TensorTemplate<TBias>* bias;
284 TosaReference::TensorTemplate<TAcc>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700285 TosaTransposeConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700286 TosaConvQuantInfo* qinfo;
287};
288
289}; // namespace TosaReference
290
291#endif