blob: 9ef4a58b62d943f44a58365cee0e1ee420695086 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27template <int Rank, DType Dtype>
28class OpArgMax : public GraphNode
29{
30public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000031 OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
38 using OutEigenType = typename GetEigenType<DType_INT32>::type;
39 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
James Ward8b390432022-08-12 20:48:56 +010048template <DType Dtype, DType AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070049class OpAvgPool2d : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
Eric Kunzee5e26762020-10-13 16:11:07 -070058 using InEigenType = typename GetEigenType<Dtype>::type;
James Ward8b390432022-08-12 20:48:56 +010059 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
Eric Kunzee5e26762020-10-13 16:11:07 -070060 using OutEigenType = typename GetEigenType<Dtype>::type;
61 using TIn = Eigen::Tensor<InEigenType, 4>;
62 using TOut = Eigen::Tensor<OutEigenType, 4>;
63
64 static constexpr int64_t QMin = GetQMin<Dtype>::value;
65 static constexpr int64_t QMax = GetQMax<Dtype>::value;
66
67protected:
68 TosaReference::TensorTemplate<TIn>* in;
69 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -070070 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070071
72protected:
73 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
Eric Kunze830add42022-01-25 22:56:46 -080074 ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
Eric Kunzee5e26762020-10-13 16:11:07 -070075};
76
James Wardd34b3fc2023-01-18 14:51:25 +000077template <DType InDtype, DType WeightDtype, DType OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070078class OpConv2d : public GraphNode
79{
80public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000081 OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 virtual ~OpConv2d();
83
84 virtual int checkTensorAttributes() final;
85 virtual int eval() final;
86
Eric Kunzee5e26762020-10-13 16:11:07 -070087 using InEigenType = typename GetEigenType<InDtype>::type;
88 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Wardd34b3fc2023-01-18 14:51:25 +000089 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
90 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -070091 using TIn = Eigen::Tensor<InEigenType, 4>;
92 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +010093 using TBias = Eigen::Tensor<OutEigenType, 1>;
94 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -070095
James Wardd34b3fc2023-01-18 14:51:25 +000096 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
97 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99protected:
100 TosaReference::TensorTemplate<TIn>* input;
101 TosaReference::TensorTemplate<TWeight>* weight;
102 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100103 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700104 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700105};
106
James Wardd34b3fc2023-01-18 14:51:25 +0000107template <DType InDtype, DType WeightDtype, DType OutDtype>
Kevin Cheng1533b852021-09-01 12:51:58 -0700108class OpConv3d : public GraphNode
109{
110public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000111 OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Kevin Cheng1533b852021-09-01 12:51:58 -0700112 virtual ~OpConv3d();
113
114 virtual int checkTensorAttributes() final;
115 virtual int eval() final;
116
Kevin Cheng1533b852021-09-01 12:51:58 -0700117 using InEigenType = typename GetEigenType<InDtype>::type;
118 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Wardd34b3fc2023-01-18 14:51:25 +0000119 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
120 using OutEigenType = typename GetEigenType<OutDtype>::type;
Kevin Cheng1533b852021-09-01 12:51:58 -0700121 using TIn = Eigen::Tensor<InEigenType, 5>;
122 using TWeight = Eigen::Tensor<WeightEigenType, 5>;
James Ward8b390432022-08-12 20:48:56 +0100123 using TBias = Eigen::Tensor<OutEigenType, 1>;
124 using TOut = Eigen::Tensor<OutEigenType, 5>;
Kevin Cheng1533b852021-09-01 12:51:58 -0700125
James Wardd34b3fc2023-01-18 14:51:25 +0000126 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
127 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Kevin Cheng1533b852021-09-01 12:51:58 -0700128
129protected:
130 TosaReference::TensorTemplate<TIn>* input;
131 TosaReference::TensorTemplate<TWeight>* weight;
132 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100133 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng1533b852021-09-01 12:51:58 -0700134 tosa::TosaConvAttribute* attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700135};
136
James Wardd34b3fc2023-01-18 14:51:25 +0000137template <DType InDtype, DType WeightDtype, DType OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700138class OpDepthwiseConv2d : public GraphNode
139{
140public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000141 OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700142 virtual ~OpDepthwiseConv2d();
143
144 virtual int checkTensorAttributes() final;
145 virtual int eval() final;
146
Eric Kunzee5e26762020-10-13 16:11:07 -0700147 using InEigenType = typename GetEigenType<InDtype>::type;
148 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Wardd34b3fc2023-01-18 14:51:25 +0000149 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
150 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 using TIn = Eigen::Tensor<InEigenType, 4>;
152 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +0100153 using TBias = Eigen::Tensor<OutEigenType, 1>;
154 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700155
James Wardd34b3fc2023-01-18 14:51:25 +0000156 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
157 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
159protected:
160 TosaReference::TensorTemplate<TIn>* input;
161 TosaReference::TensorTemplate<TWeight>* weight;
162 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100163 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700164 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700165};
166
James Wardd34b3fc2023-01-18 14:51:25 +0000167template <DType InDtype, DType WeightDtype, DType OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700168class OpFullyConnected : public GraphNode
169{
170public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000171 OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700172 virtual ~OpFullyConnected();
173
174 virtual int checkTensorAttributes() final;
175 virtual int eval() final;
176
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 using InEigenType = typename GetEigenType<InDtype>::type;
178 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Wardd34b3fc2023-01-18 14:51:25 +0000179 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
180 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700181 using TIn = Eigen::Tensor<InEigenType, 2>;
182 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
James Ward8b390432022-08-12 20:48:56 +0100183 using TBias = Eigen::Tensor<OutEigenType, 1>;
184 using TOut = Eigen::Tensor<OutEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700185
James Wardd34b3fc2023-01-18 14:51:25 +0000186 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
187 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700188
189protected:
190 TosaReference::TensorTemplate<TIn>* input;
191 TosaReference::TensorTemplate<TWeight>* weight;
192 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100193 TosaReference::TensorTemplate<TOut>* output;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000194
195 tosa::TosaFullyConnectedAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700196};
197
James Wardd34b3fc2023-01-18 14:51:25 +0000198template <DType Dtype, DType OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700199class OpMatMul : public GraphNode
200{
201public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000202 OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 virtual ~OpMatMul();
204
205 virtual int checkTensorAttributes() final;
206 virtual int eval() final;
207
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 using InEigenType = typename GetEigenType<Dtype>::type;
James Wardd34b3fc2023-01-18 14:51:25 +0000209 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
210 using OutEigenType = typename GetEigenType<OutDtype>::type;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700211 using TIn = Eigen::Tensor<InEigenType, 3>;
James Ward8b390432022-08-12 20:48:56 +0100212 using TOut = Eigen::Tensor<OutEigenType, 3>;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700213 using TInRank2 = Eigen::Tensor<InEigenType, 2>;
214 using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
James Wardd34b3fc2023-01-18 14:51:25 +0000215 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
216 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700217
218protected:
219 TosaReference::TensorTemplate<TIn>* a;
220 TosaReference::TensorTemplate<TIn>* b;
James Ward8b390432022-08-12 20:48:56 +0100221 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700222 int64_t N;
223 int64_t H;
224 int64_t W;
225 int64_t C;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000226
227 tosa::TosaMatMulAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700228};
229
230template <DType Dtype>
231class OpMaxPool2d : public GraphNode
232{
233public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000234 OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700235 virtual ~OpMaxPool2d();
236
237 virtual int checkTensorAttributes();
238 virtual int eval();
239
240 using InEigenType = typename GetEigenType<Dtype>::type;
241 using OutEigenType = typename GetEigenType<Dtype>::type;
242 using TIn = Eigen::Tensor<InEigenType, 4>;
243 using TOut = Eigen::Tensor<OutEigenType, 4>;
244
245protected:
246 TosaReference::TensorTemplate<TIn>* in;
247 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -0700248 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700249};
250
Luke Hutton261b7b62023-01-10 14:50:31 +0000251template <DType Dtype>
Luke Hutton57287132023-02-06 14:54:18 +0000252class OpFFT2d : public GraphNode
253{
254public:
255 OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
256 virtual ~OpFFT2d();
257
258 virtual int checkTensorAttributes() final;
259 virtual int eval() final;
260
261 using InEigenType = typename GetEigenType<Dtype>::type;
262 using OutEigenType = typename GetEigenType<Dtype>::type;
263 using TIn = Eigen::Tensor<InEigenType, 3>;
264 using TOut = Eigen::Tensor<OutEigenType, 3>;
265
266protected:
267 TosaReference::TensorTemplate<TIn>* in_real;
268 TosaReference::TensorTemplate<TIn>* in_imag;
269 TosaReference::TensorTemplate<TOut>* out_real;
270 TosaReference::TensorTemplate<TOut>* out_imag;
271 tosa::TosaFFTAttribute* attribute;
272};
273
274template <DType Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +0000275class OpRFFT2d : public GraphNode
276{
277public:
278 OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
279 virtual ~OpRFFT2d();
280
281 virtual int checkTensorAttributes() final;
282 virtual int eval() final;
283
284 using InEigenType = typename GetEigenType<Dtype>::type;
285 using OutEigenType = typename GetEigenType<Dtype>::type;
286 using TIn = Eigen::Tensor<InEigenType, 3>;
287 using TOut = Eigen::Tensor<OutEigenType, 3>;
288
289protected:
290 TosaReference::TensorTemplate<TIn>* in;
291 TosaReference::TensorTemplate<TOut>* out_real;
292 TosaReference::TensorTemplate<TOut>* out_imag;
293};
294
James Wardd34b3fc2023-01-18 14:51:25 +0000295template <DType InDtype, DType WeightDtype, DType OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700296class OpTransposeConv2d : public GraphNode
297{
298public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000299 OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 virtual ~OpTransposeConv2d();
301
302 virtual int checkTensorAttributes() final;
303 virtual int eval() final;
304
Eric Kunzee5e26762020-10-13 16:11:07 -0700305 using InEigenType = typename GetEigenType<InDtype>::type;
306 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
James Wardd34b3fc2023-01-18 14:51:25 +0000307 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
308 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700309 using TIn = Eigen::Tensor<InEigenType, 4>;
310 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +0100311 using TBias = Eigen::Tensor<OutEigenType, 1>;
312 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700313
James Wardd34b3fc2023-01-18 14:51:25 +0000314 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
315 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700316
317protected:
318 TosaReference::TensorTemplate<TIn>* input;
319 TosaReference::TensorTemplate<TWeight>* weight;
320 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100321 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700322 TosaTransposeConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700323};
324
325}; // namespace TosaReference
326
327#endif