blob: 24eadeb84870e461faf868d9793877aca00e08aa [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27template <int Rank, DType Dtype>
28class OpArgMax : public GraphNode
29{
30public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000031 OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
38 using OutEigenType = typename GetEigenType<DType_INT32>::type;
39 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
48template <DType Dtype>
49class OpAvgPool2d : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
58 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
59 using InEigenType = typename GetEigenType<Dtype>::type;
60 using AccEigenType = typename GetEigenType<AccDtype>::type;
61 using OutEigenType = typename GetEigenType<Dtype>::type;
62 using TIn = Eigen::Tensor<InEigenType, 4>;
63 using TOut = Eigen::Tensor<OutEigenType, 4>;
64
65 static constexpr int64_t QMin = GetQMin<Dtype>::value;
66 static constexpr int64_t QMax = GetQMax<Dtype>::value;
67
68protected:
69 TosaReference::TensorTemplate<TIn>* in;
70 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -070071 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070072
73protected:
74 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
Eric Kunze830add42022-01-25 22:56:46 -080075 ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
Eric Kunzee5e26762020-10-13 16:11:07 -070076};
77
78template <DType InDtype, DType WeightDtype>
79class OpConv2d : public GraphNode
80{
81public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000082 OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070083 virtual ~OpConv2d();
84
85 virtual int checkTensorAttributes() final;
86 virtual int eval() final;
87
88 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
89
90 using InEigenType = typename GetEigenType<InDtype>::type;
91 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
92 using AccEigenType = typename GetEigenType<AccDtype>::type;
93 using TIn = Eigen::Tensor<InEigenType, 4>;
94 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
95 using TBias = Eigen::Tensor<AccEigenType, 1>;
96 using TAcc = Eigen::Tensor<AccEigenType, 4>;
97
98 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
99 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
100
101protected:
102 TosaReference::TensorTemplate<TIn>* input;
103 TosaReference::TensorTemplate<TWeight>* weight;
104 TosaReference::TensorTemplate<TBias>* bias;
105 TosaReference::TensorTemplate<TAcc>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700106 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700107};
108
109template <DType InDtype, DType WeightDtype>
Kevin Cheng1533b852021-09-01 12:51:58 -0700110class OpConv3d : public GraphNode
111{
112public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000113 OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Kevin Cheng1533b852021-09-01 12:51:58 -0700114 virtual ~OpConv3d();
115
116 virtual int checkTensorAttributes() final;
117 virtual int eval() final;
118
119 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
120
121 using InEigenType = typename GetEigenType<InDtype>::type;
122 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
123 using AccEigenType = typename GetEigenType<AccDtype>::type;
124 using TIn = Eigen::Tensor<InEigenType, 5>;
125 using TWeight = Eigen::Tensor<WeightEigenType, 5>;
126 using TBias = Eigen::Tensor<AccEigenType, 1>;
127 using TAcc = Eigen::Tensor<AccEigenType, 5>;
128
129 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
130 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
131
132protected:
133 TosaReference::TensorTemplate<TIn>* input;
134 TosaReference::TensorTemplate<TWeight>* weight;
135 TosaReference::TensorTemplate<TBias>* bias;
136 TosaReference::TensorTemplate<TAcc>* output;
137 tosa::TosaConvAttribute* attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700138};
139
140template <DType InDtype, DType WeightDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700141class OpDepthwiseConv2d : public GraphNode
142{
143public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000144 OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700145 virtual ~OpDepthwiseConv2d();
146
147 virtual int checkTensorAttributes() final;
148 virtual int eval() final;
149
150 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
151
152 using InEigenType = typename GetEigenType<InDtype>::type;
153 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
154 using AccEigenType = typename GetEigenType<AccDtype>::type;
155 using TIn = Eigen::Tensor<InEigenType, 4>;
156 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
157 using TBias = Eigen::Tensor<AccEigenType, 1>;
158 using TAcc = Eigen::Tensor<AccEigenType, 4>;
159
160 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
161 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
162
163protected:
164 TosaReference::TensorTemplate<TIn>* input;
165 TosaReference::TensorTemplate<TWeight>* weight;
166 TosaReference::TensorTemplate<TBias>* bias;
167 TosaReference::TensorTemplate<TAcc>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700168 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700169};
170
171template <DType InDtype, DType WeightDtype>
172class OpFullyConnected : public GraphNode
173{
174public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000175 OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700176 virtual ~OpFullyConnected();
177
178 virtual int checkTensorAttributes() final;
179 virtual int eval() final;
180
181 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
182 using InEigenType = typename GetEigenType<InDtype>::type;
183 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
184 using AccEigenType = typename GetEigenType<AccDtype>::type;
185 using TIn = Eigen::Tensor<InEigenType, 2>;
186 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
187 using TBias = Eigen::Tensor<AccEigenType, 1>;
188 using TAcc = Eigen::Tensor<AccEigenType, 2>;
189
190 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
191 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
192
193protected:
194 TosaReference::TensorTemplate<TIn>* input;
195 TosaReference::TensorTemplate<TWeight>* weight;
196 TosaReference::TensorTemplate<TBias>* bias;
197 TosaReference::TensorTemplate<TAcc>* output;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000198
199 tosa::TosaFullyConnectedAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700200};
201
202template <DType Dtype>
203class OpMatMul : public GraphNode
204{
205public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000206 OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 virtual ~OpMatMul();
208
209 virtual int checkTensorAttributes() final;
210 virtual int eval() final;
211
212 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
213 using InEigenType = typename GetEigenType<Dtype>::type;
214 using AccEigenType = typename GetEigenType<AccDtype>::type;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700215 using TIn = Eigen::Tensor<InEigenType, 3>;
216 using TAcc = Eigen::Tensor<AccEigenType, 3>;
217 using TInRank2 = Eigen::Tensor<InEigenType, 2>;
218 using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
220 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
221
222protected:
223 TosaReference::TensorTemplate<TIn>* a;
224 TosaReference::TensorTemplate<TIn>* b;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700225 TosaReference::TensorTemplate<TAcc>* output;
226 int64_t N;
227 int64_t H;
228 int64_t W;
229 int64_t C;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000230
231 tosa::TosaMatMulAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700232};
233
234template <DType Dtype>
235class OpMaxPool2d : public GraphNode
236{
237public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000238 OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700239 virtual ~OpMaxPool2d();
240
241 virtual int checkTensorAttributes();
242 virtual int eval();
243
244 using InEigenType = typename GetEigenType<Dtype>::type;
245 using OutEigenType = typename GetEigenType<Dtype>::type;
246 using TIn = Eigen::Tensor<InEigenType, 4>;
247 using TOut = Eigen::Tensor<OutEigenType, 4>;
248
249protected:
250 TosaReference::TensorTemplate<TIn>* in;
251 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -0700252 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700253};
254
255template <DType InDtype, DType WeightDtype>
256class OpTransposeConv2d : public GraphNode
257{
258public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000259 OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700260 virtual ~OpTransposeConv2d();
261
262 virtual int checkTensorAttributes() final;
263 virtual int eval() final;
264
265 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
266
267 using InEigenType = typename GetEigenType<InDtype>::type;
268 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
269 using AccEigenType = typename GetEigenType<AccDtype>::type;
270 using TIn = Eigen::Tensor<InEigenType, 4>;
271 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
272 using TBias = Eigen::Tensor<AccEigenType, 1>;
273 using TAcc = Eigen::Tensor<AccEigenType, 4>;
274
275 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
276 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
277
278protected:
279 TosaReference::TensorTemplate<TIn>* input;
280 TosaReference::TensorTemplate<TWeight>* weight;
281 TosaReference::TensorTemplate<TBias>* bias;
282 TosaReference::TensorTemplate<TAcc>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700283 TosaTransposeConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700284};
285
286}; // namespace TosaReference
287
288#endif