blob: 9aaa140caa8823e430819a6a5a31abe4705bdfad [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27template <int Rank, DType Dtype>
28class OpArgMax : public GraphNode
29{
30public:
31 OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
32 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
38 using OutEigenType = typename GetEigenType<DType_INT32>::type;
39 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
48template <DType Dtype>
49class OpAvgPool2d : public GraphNode
50{
51public:
52 OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
53 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
58 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
59 using InEigenType = typename GetEigenType<Dtype>::type;
60 using AccEigenType = typename GetEigenType<AccDtype>::type;
61 using OutEigenType = typename GetEigenType<Dtype>::type;
62 using TIn = Eigen::Tensor<InEigenType, 4>;
63 using TOut = Eigen::Tensor<OutEigenType, 4>;
64
65 static constexpr int64_t QMin = GetQMin<Dtype>::value;
66 static constexpr int64_t QMax = GetQMax<Dtype>::value;
67
68protected:
69 TosaReference::TensorTemplate<TIn>* in;
70 TosaReference::TensorTemplate<TOut>* out;
71 tosa::TosaPool2dAttribute* attribute;
72 tosa::TosaUnaryQuantInfo* qinfo;
73
74protected:
75 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
76 ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
77};
78
79template <DType InDtype, DType WeightDtype>
80class OpConv2d : public GraphNode
81{
82public:
83 OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
84 virtual ~OpConv2d();
85
86 virtual int checkTensorAttributes() final;
87 virtual int eval() final;
88
89 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
90
91 using InEigenType = typename GetEigenType<InDtype>::type;
92 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
93 using AccEigenType = typename GetEigenType<AccDtype>::type;
94 using TIn = Eigen::Tensor<InEigenType, 4>;
95 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
96 using TBias = Eigen::Tensor<AccEigenType, 1>;
97 using TAcc = Eigen::Tensor<AccEigenType, 4>;
98
99 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
100 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
101
102protected:
103 TosaReference::TensorTemplate<TIn>* input;
104 TosaReference::TensorTemplate<TWeight>* weight;
105 TosaReference::TensorTemplate<TBias>* bias;
106 TosaReference::TensorTemplate<TAcc>* output;
107 tosa::TosaConv2dAttribute* attribute;
108 tosa::TosaConvQuantInfo* qinfo;
109};
110
111template <DType InDtype, DType WeightDtype>
112class OpDepthwiseConv2d : public GraphNode
113{
114public:
115 OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
116 virtual ~OpDepthwiseConv2d();
117
118 virtual int checkTensorAttributes() final;
119 virtual int eval() final;
120
121 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
122
123 using InEigenType = typename GetEigenType<InDtype>::type;
124 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
125 using AccEigenType = typename GetEigenType<AccDtype>::type;
126 using TIn = Eigen::Tensor<InEigenType, 4>;
127 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
128 using TBias = Eigen::Tensor<AccEigenType, 1>;
129 using TAcc = Eigen::Tensor<AccEigenType, 4>;
130
131 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
132 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
133
134protected:
135 TosaReference::TensorTemplate<TIn>* input;
136 TosaReference::TensorTemplate<TWeight>* weight;
137 TosaReference::TensorTemplate<TBias>* bias;
138 TosaReference::TensorTemplate<TAcc>* output;
139 tosa::TosaConv2dAttribute* attribute;
140 tosa::TosaConvQuantInfo* qinfo;
141};
142
143template <DType InDtype, DType WeightDtype>
144class OpFullyConnected : public GraphNode
145{
146public:
147 OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
148 virtual ~OpFullyConnected();
149
150 virtual int checkTensorAttributes() final;
151 virtual int eval() final;
152
153 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
154 using InEigenType = typename GetEigenType<InDtype>::type;
155 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
156 using AccEigenType = typename GetEigenType<AccDtype>::type;
157 using TIn = Eigen::Tensor<InEigenType, 2>;
158 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
159 using TBias = Eigen::Tensor<AccEigenType, 1>;
160 using TAcc = Eigen::Tensor<AccEigenType, 2>;
161
162 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
163 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
164
165protected:
166 TosaReference::TensorTemplate<TIn>* input;
167 TosaReference::TensorTemplate<TWeight>* weight;
168 TosaReference::TensorTemplate<TBias>* bias;
169 TosaReference::TensorTemplate<TAcc>* output;
170 tosa::TosaConvQuantInfo* qinfo;
171};
172
173template <DType Dtype>
174class OpMatMul : public GraphNode
175{
176public:
177 OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
178 virtual ~OpMatMul();
179
180 virtual int checkTensorAttributes() final;
181 virtual int eval() final;
182
183 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
184 using InEigenType = typename GetEigenType<Dtype>::type;
185 using AccEigenType = typename GetEigenType<AccDtype>::type;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700186 using TIn = Eigen::Tensor<InEigenType, 3>;
187 using TAcc = Eigen::Tensor<AccEigenType, 3>;
188 using TInRank2 = Eigen::Tensor<InEigenType, 2>;
189 using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700190 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
191 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
192
193protected:
194 TosaReference::TensorTemplate<TIn>* a;
195 TosaReference::TensorTemplate<TIn>* b;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700196 TosaReference::TensorTemplate<TAcc>* output;
197 int64_t N;
198 int64_t H;
199 int64_t W;
200 int64_t C;
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 tosa::TosaMatMulQuantInfo* qinfo;
202};
203
204template <DType Dtype>
205class OpMaxPool2d : public GraphNode
206{
207public:
208 OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
209 virtual ~OpMaxPool2d();
210
211 virtual int checkTensorAttributes();
212 virtual int eval();
213
214 using InEigenType = typename GetEigenType<Dtype>::type;
215 using OutEigenType = typename GetEigenType<Dtype>::type;
216 using TIn = Eigen::Tensor<InEigenType, 4>;
217 using TOut = Eigen::Tensor<OutEigenType, 4>;
218
219protected:
220 TosaReference::TensorTemplate<TIn>* in;
221 TosaReference::TensorTemplate<TOut>* out;
222 tosa::TosaPool2dAttribute* attribute;
223};
224
225template <DType InDtype, DType WeightDtype>
226class OpTransposeConv2d : public GraphNode
227{
228public:
229 OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
230 virtual ~OpTransposeConv2d();
231
232 virtual int checkTensorAttributes() final;
233 virtual int eval() final;
234
235 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
236
237 using InEigenType = typename GetEigenType<InDtype>::type;
238 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
239 using AccEigenType = typename GetEigenType<AccDtype>::type;
240 using TIn = Eigen::Tensor<InEigenType, 4>;
241 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
242 using TBias = Eigen::Tensor<AccEigenType, 1>;
243 using TAcc = Eigen::Tensor<AccEigenType, 4>;
244
245 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
246 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
247
248protected:
249 TosaReference::TensorTemplate<TIn>* input;
250 TosaReference::TensorTemplate<TWeight>* weight;
251 TosaReference::TensorTemplate<TBias>* bias;
252 TosaReference::TensorTemplate<TAcc>* output;
253 TosaTransposeConv2dAttribute* attribute;
254 TosaConvQuantInfo* qinfo;
255};
256
257}; // namespace TosaReference
258
259#endif