blob: fd6dd25159579a0b0615ecfa9f2313ae7bdc1b76 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27template <int Rank, DType Dtype>
28class OpArgMax : public GraphNode
29{
30public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000031 OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
38 using OutEigenType = typename GetEigenType<DType_INT32>::type;
39 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
James Ward8b390432022-08-12 20:48:56 +010048template <DType Dtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070049class OpAvgPool2d : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
Eric Kunzee5e26762020-10-13 16:11:07 -070058 using InEigenType = typename GetEigenType<Dtype>::type;
James Ward8b390432022-08-12 20:48:56 +010059 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
Eric Kunzee5e26762020-10-13 16:11:07 -070060 using OutEigenType = typename GetEigenType<Dtype>::type;
61 using TIn = Eigen::Tensor<InEigenType, 4>;
62 using TOut = Eigen::Tensor<OutEigenType, 4>;
63
64 static constexpr int64_t QMin = GetQMin<Dtype>::value;
65 static constexpr int64_t QMax = GetQMax<Dtype>::value;
66
67protected:
68 TosaReference::TensorTemplate<TIn>* in;
69 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -070070 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070071
72protected:
73 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
Eric Kunze830add42022-01-25 22:56:46 -080074 ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
Eric Kunzee5e26762020-10-13 16:11:07 -070075};
76
James Ward8b390432022-08-12 20:48:56 +010077template <DType InDtype, DType WeightDtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070078class OpConv2d : public GraphNode
79{
80public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000081 OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 virtual ~OpConv2d();
83
84 virtual int checkTensorAttributes() final;
85 virtual int eval() final;
86
Eric Kunzee5e26762020-10-13 16:11:07 -070087 using InEigenType = typename GetEigenType<InDtype>::type;
88 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Ward8b390432022-08-12 20:48:56 +010089 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
90 using OutEigenType = typename GetEigenType<AccDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -070091 using TIn = Eigen::Tensor<InEigenType, 4>;
92 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +010093 using TBias = Eigen::Tensor<OutEigenType, 1>;
94 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -070095
96 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
97 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
98
99protected:
100 TosaReference::TensorTemplate<TIn>* input;
101 TosaReference::TensorTemplate<TWeight>* weight;
102 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100103 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700104 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700105};
106
James Ward8b390432022-08-12 20:48:56 +0100107template <DType InDtype, DType WeightDtype, DType AccDtype>
Kevin Cheng1533b852021-09-01 12:51:58 -0700108class OpConv3d : public GraphNode
109{
110public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000111 OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Kevin Cheng1533b852021-09-01 12:51:58 -0700112 virtual ~OpConv3d();
113
114 virtual int checkTensorAttributes() final;
115 virtual int eval() final;
116
Kevin Cheng1533b852021-09-01 12:51:58 -0700117 using InEigenType = typename GetEigenType<InDtype>::type;
118 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Ward8b390432022-08-12 20:48:56 +0100119 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
120 using OutEigenType = typename GetEigenType<AccDtype>::type;
Kevin Cheng1533b852021-09-01 12:51:58 -0700121 using TIn = Eigen::Tensor<InEigenType, 5>;
122 using TWeight = Eigen::Tensor<WeightEigenType, 5>;
James Ward8b390432022-08-12 20:48:56 +0100123 using TBias = Eigen::Tensor<OutEigenType, 1>;
124 using TOut = Eigen::Tensor<OutEigenType, 5>;
Kevin Cheng1533b852021-09-01 12:51:58 -0700125
126 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
127 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
128
129protected:
130 TosaReference::TensorTemplate<TIn>* input;
131 TosaReference::TensorTemplate<TWeight>* weight;
132 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100133 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng1533b852021-09-01 12:51:58 -0700134 tosa::TosaConvAttribute* attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700135};
136
James Ward8b390432022-08-12 20:48:56 +0100137template <DType InDtype, DType WeightDtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700138class OpDepthwiseConv2d : public GraphNode
139{
140public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000141 OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700142 virtual ~OpDepthwiseConv2d();
143
144 virtual int checkTensorAttributes() final;
145 virtual int eval() final;
146
Eric Kunzee5e26762020-10-13 16:11:07 -0700147 using InEigenType = typename GetEigenType<InDtype>::type;
148 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Ward8b390432022-08-12 20:48:56 +0100149 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
150 using OutEigenType = typename GetEigenType<AccDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 using TIn = Eigen::Tensor<InEigenType, 4>;
152 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +0100153 using TBias = Eigen::Tensor<OutEigenType, 1>;
154 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700155
156 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
157 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
158
159protected:
160 TosaReference::TensorTemplate<TIn>* input;
161 TosaReference::TensorTemplate<TWeight>* weight;
162 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100163 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700164 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700165};
166
James Ward8b390432022-08-12 20:48:56 +0100167template <DType InDtype, DType WeightDtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700168class OpFullyConnected : public GraphNode
169{
170public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000171 OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700172 virtual ~OpFullyConnected();
173
174 virtual int checkTensorAttributes() final;
175 virtual int eval() final;
176
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 using InEigenType = typename GetEigenType<InDtype>::type;
178 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Ward8b390432022-08-12 20:48:56 +0100179 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
180 using OutEigenType = typename GetEigenType<AccDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700181 using TIn = Eigen::Tensor<InEigenType, 2>;
182 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
James Ward8b390432022-08-12 20:48:56 +0100183 using TBias = Eigen::Tensor<OutEigenType, 1>;
184 using TOut = Eigen::Tensor<OutEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700185
186 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
187 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
188
189protected:
190 TosaReference::TensorTemplate<TIn>* input;
191 TosaReference::TensorTemplate<TWeight>* weight;
192 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100193 TosaReference::TensorTemplate<TOut>* output;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000194
195 tosa::TosaFullyConnectedAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700196};
197
James Ward8b390432022-08-12 20:48:56 +0100198template <DType Dtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700199class OpMatMul : public GraphNode
200{
201public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000202 OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 virtual ~OpMatMul();
204
205 virtual int checkTensorAttributes() final;
206 virtual int eval() final;
207
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 using InEigenType = typename GetEigenType<Dtype>::type;
James Ward8b390432022-08-12 20:48:56 +0100209 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
210 using OutEigenType = typename GetEigenType<AccDtype>::type;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700211 using TIn = Eigen::Tensor<InEigenType, 3>;
James Ward8b390432022-08-12 20:48:56 +0100212 using TOut = Eigen::Tensor<OutEigenType, 3>;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700213 using TInRank2 = Eigen::Tensor<InEigenType, 2>;
214 using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
216 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
217
218protected:
219 TosaReference::TensorTemplate<TIn>* a;
220 TosaReference::TensorTemplate<TIn>* b;
James Ward8b390432022-08-12 20:48:56 +0100221 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700222 int64_t N;
223 int64_t H;
224 int64_t W;
225 int64_t C;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000226
227 tosa::TosaMatMulAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700228};
229
230template <DType Dtype>
231class OpMaxPool2d : public GraphNode
232{
233public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000234 OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700235 virtual ~OpMaxPool2d();
236
237 virtual int checkTensorAttributes();
238 virtual int eval();
239
240 using InEigenType = typename GetEigenType<Dtype>::type;
241 using OutEigenType = typename GetEigenType<Dtype>::type;
242 using TIn = Eigen::Tensor<InEigenType, 4>;
243 using TOut = Eigen::Tensor<OutEigenType, 4>;
244
245protected:
246 TosaReference::TensorTemplate<TIn>* in;
247 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -0700248 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700249};
250
James Ward8b390432022-08-12 20:48:56 +0100251template <DType InDtype, DType WeightDtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700252class OpTransposeConv2d : public GraphNode
253{
254public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000255 OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700256 virtual ~OpTransposeConv2d();
257
258 virtual int checkTensorAttributes() final;
259 virtual int eval() final;
260
Eric Kunzee5e26762020-10-13 16:11:07 -0700261 using InEigenType = typename GetEigenType<InDtype>::type;
262 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Ward8b390432022-08-12 20:48:56 +0100263 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
264 using OutEigenType = typename GetEigenType<AccDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 using TIn = Eigen::Tensor<InEigenType, 4>;
266 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +0100267 using TBias = Eigen::Tensor<OutEigenType, 1>;
268 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700269
270 static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
271 static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
272
273protected:
274 TosaReference::TensorTemplate<TIn>* input;
275 TosaReference::TensorTemplate<TWeight>* weight;
276 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100277 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700278 TosaTransposeConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700279};
280
281}; // namespace TosaReference
282
283#endif