// Provenance (repository web viewer): blob 26ce84bc2ffb17e6eb668983f5ac9fe1c3229179;
// last change e5e2676 by Eric Kunze, 2020-10-13 16:11:07 -0700.

// Copyright (c) 2020, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27template <int Rank, DType Dtype>
28class OpArgMax : public GraphNode
29{
30public:
31 OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
32 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
38 using OutEigenType = typename GetEigenType<DType_INT32>::type;
39 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
48template <DType Dtype>
49class OpAvgPool2d : public GraphNode
50{
51public:
52 OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
53 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
58 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
59 using InEigenType = typename GetEigenType<Dtype>::type;
60 using AccEigenType = typename GetEigenType<AccDtype>::type;
61 using OutEigenType = typename GetEigenType<Dtype>::type;
62 using TIn = Eigen::Tensor<InEigenType, 4>;
63 using TOut = Eigen::Tensor<OutEigenType, 4>;
64
65 static constexpr int64_t QMin = GetQMin<Dtype>::value;
66 static constexpr int64_t QMax = GetQMax<Dtype>::value;
67
68protected:
69 TosaReference::TensorTemplate<TIn>* in;
70 TosaReference::TensorTemplate<TOut>* out;
71 tosa::TosaPool2dAttribute* attribute;
72 tosa::TosaUnaryQuantInfo* qinfo;
73
74protected:
75 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
76 ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
77};
78
79template <DType InDtype, DType WeightDtype>
80class OpConv2d : public GraphNode
81{
82public:
83 OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
84 virtual ~OpConv2d();
85
86 virtual int checkTensorAttributes() final;
87 virtual int eval() final;
88
89 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
90
91 using InEigenType = typename GetEigenType<InDtype>::type;
92 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
93 using AccEigenType = typename GetEigenType<AccDtype>::type;
94 using TIn = Eigen::Tensor<InEigenType, 4>;
95 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
96 using TBias = Eigen::Tensor<AccEigenType, 1>;
97 using TAcc = Eigen::Tensor<AccEigenType, 4>;
98
99 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
100 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
101
102protected:
103 TosaReference::TensorTemplate<TIn>* input;
104 TosaReference::TensorTemplate<TWeight>* weight;
105 TosaReference::TensorTemplate<TBias>* bias;
106 TosaReference::TensorTemplate<TAcc>* output;
107 tosa::TosaConv2dAttribute* attribute;
108 tosa::TosaConvQuantInfo* qinfo;
109};
110
111template <DType InDtype, DType WeightDtype>
112class OpDepthwiseConv2d : public GraphNode
113{
114public:
115 OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
116 virtual ~OpDepthwiseConv2d();
117
118 virtual int checkTensorAttributes() final;
119 virtual int eval() final;
120
121 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
122
123 using InEigenType = typename GetEigenType<InDtype>::type;
124 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
125 using AccEigenType = typename GetEigenType<AccDtype>::type;
126 using TIn = Eigen::Tensor<InEigenType, 4>;
127 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
128 using TBias = Eigen::Tensor<AccEigenType, 1>;
129 using TAcc = Eigen::Tensor<AccEigenType, 4>;
130
131 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
132 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
133
134protected:
135 TosaReference::TensorTemplate<TIn>* input;
136 TosaReference::TensorTemplate<TWeight>* weight;
137 TosaReference::TensorTemplate<TBias>* bias;
138 TosaReference::TensorTemplate<TAcc>* output;
139 tosa::TosaConv2dAttribute* attribute;
140 tosa::TosaConvQuantInfo* qinfo;
141};
142
143template <DType InDtype, DType WeightDtype>
144class OpFullyConnected : public GraphNode
145{
146public:
147 OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
148 virtual ~OpFullyConnected();
149
150 virtual int checkTensorAttributes() final;
151 virtual int eval() final;
152
153 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
154 using InEigenType = typename GetEigenType<InDtype>::type;
155 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
156 using AccEigenType = typename GetEigenType<AccDtype>::type;
157 using TIn = Eigen::Tensor<InEigenType, 2>;
158 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
159 using TBias = Eigen::Tensor<AccEigenType, 1>;
160 using TAcc = Eigen::Tensor<AccEigenType, 2>;
161
162 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
163 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
164
165protected:
166 TosaReference::TensorTemplate<TIn>* input;
167 TosaReference::TensorTemplate<TWeight>* weight;
168 TosaReference::TensorTemplate<TBias>* bias;
169 TosaReference::TensorTemplate<TAcc>* output;
170 tosa::TosaConvQuantInfo* qinfo;
171};
172
173template <DType Dtype>
174class OpMatMul : public GraphNode
175{
176public:
177 OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
178 virtual ~OpMatMul();
179
180 virtual int checkTensorAttributes() final;
181 virtual int eval() final;
182
183 static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
184 using InEigenType = typename GetEigenType<Dtype>::type;
185 using AccEigenType = typename GetEigenType<AccDtype>::type;
186 using TIn = Eigen::Tensor<InEigenType, 2>;
187 using TAcc = Eigen::Tensor<AccEigenType, 2>;
188 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
189 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
190
191protected:
192 TosaReference::TensorTemplate<TIn>* a;
193 TosaReference::TensorTemplate<TIn>* b;
194 TosaReference::TensorTemplate<TAcc>* c;
195 tosa::TosaMatMulQuantInfo* qinfo;
196};
197
198template <DType Dtype>
199class OpMaxPool2d : public GraphNode
200{
201public:
202 OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
203 virtual ~OpMaxPool2d();
204
205 virtual int checkTensorAttributes();
206 virtual int eval();
207
208 using InEigenType = typename GetEigenType<Dtype>::type;
209 using OutEigenType = typename GetEigenType<Dtype>::type;
210 using TIn = Eigen::Tensor<InEigenType, 4>;
211 using TOut = Eigen::Tensor<OutEigenType, 4>;
212
213protected:
214 TosaReference::TensorTemplate<TIn>* in;
215 TosaReference::TensorTemplate<TOut>* out;
216 tosa::TosaPool2dAttribute* attribute;
217};
218
219template <DType InDtype, DType WeightDtype>
220class OpTransposeConv2d : public GraphNode
221{
222public:
223 OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
224 virtual ~OpTransposeConv2d();
225
226 virtual int checkTensorAttributes() final;
227 virtual int eval() final;
228
229 static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
230
231 using InEigenType = typename GetEigenType<InDtype>::type;
232 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
233 using AccEigenType = typename GetEigenType<AccDtype>::type;
234 using TIn = Eigen::Tensor<InEigenType, 4>;
235 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
236 using TBias = Eigen::Tensor<AccEigenType, 1>;
237 using TAcc = Eigen::Tensor<AccEigenType, 4>;
238
239 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
240 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
241
242protected:
243 TosaReference::TensorTemplate<TIn>* input;
244 TosaReference::TensorTemplate<TWeight>* weight;
245 TosaReference::TensorTemplate<TBias>* bias;
246 TosaReference::TensorTemplate<TAcc>* output;
247 TosaTransposeConv2dAttribute* attribute;
248 TosaConvQuantInfo* qinfo;
249};
250
251}; // namespace TosaReference
252
253#endif