// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_TENSOR_OPS_H
17#define OPS_TENSOR_OPS_H
18
19#include "graph_node.h"
20#include "quant_util.h"
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
Tai Lya4d748b2023-03-28 22:06:56 +000027template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070028class OpArgMax : public GraphNode
29{
30public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000031 OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 virtual ~OpArgMax();
33
34 virtual int checkTensorAttributes();
35 virtual int eval();
36
37 using InEigenType = typename GetEigenType<Dtype>::type;
Tai Lya4d748b2023-03-28 22:06:56 +000038 using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_INT32>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -070039 using TIn = Eigen::Tensor<InEigenType, Rank>;
40 using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
41
42protected:
43 TosaAxisAttribute* attribute;
44 TosaReference::TensorTemplate<TIn>* input;
45 TosaReference::TensorTemplate<TOut>* output;
46};
47
Tai Lya4d748b2023-03-28 22:06:56 +000048template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070049class OpAvgPool2d : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpAvgPool2d();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
Jerry Ge9c9c8da2023-07-19 23:08:16 +000058 using InEigenType = typename GetEigenType<Dtype>::type;
59 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
60 using OutEigenType = typename GetEigenType<Dtype>::type;
61 using TIn = Eigen::Tensor<InEigenType, 4>;
62 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -070063
64 static constexpr int64_t QMin = GetQMin<Dtype>::value;
65 static constexpr int64_t QMax = GetQMax<Dtype>::value;
66
67protected:
68 TosaReference::TensorTemplate<TIn>* in;
69 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -070070 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070071
72protected:
73 // return a 1D [N] tensor that describes a how many valid elements covered in the input space
Jerry Ge9c9c8da2023-07-19 23:08:16 +000074 ETensor1<int32_t> calculate_div_map_1d(
75 int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
Eric Kunzee5e26762020-10-13 16:11:07 -070076};
77
Tai Lyf36f2562024-03-14 16:21:29 +000078template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070079class OpConv2d : public GraphNode
80{
81public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000082 OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070083 virtual ~OpConv2d();
84
85 virtual int checkTensorAttributes() final;
86 virtual int eval() final;
87
Eric Kunzee5e26762020-10-13 16:11:07 -070088 using InEigenType = typename GetEigenType<InDtype>::type;
89 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
Tai Lyf36f2562024-03-14 16:21:29 +000090 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
James Wardd34b3fc2023-01-18 14:51:25 +000091 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -070092 using TIn = Eigen::Tensor<InEigenType, 4>;
93 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +010094 using TBias = Eigen::Tensor<OutEigenType, 1>;
95 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -070096
James Wardd34b3fc2023-01-18 14:51:25 +000097 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
98 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -070099
100protected:
101 TosaReference::TensorTemplate<TIn>* input;
102 TosaReference::TensorTemplate<TWeight>* weight;
103 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100104 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700105 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700106};
107
Tai Lyf36f2562024-03-14 16:21:29 +0000108template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
Kevin Cheng1533b852021-09-01 12:51:58 -0700109class OpConv3d : public GraphNode
110{
111public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000112 OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Kevin Cheng1533b852021-09-01 12:51:58 -0700113 virtual ~OpConv3d();
114
115 virtual int checkTensorAttributes() final;
116 virtual int eval() final;
117
Kevin Cheng1533b852021-09-01 12:51:58 -0700118 using InEigenType = typename GetEigenType<InDtype>::type;
119 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
Tai Lyf36f2562024-03-14 16:21:29 +0000120 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
James Wardd34b3fc2023-01-18 14:51:25 +0000121 using OutEigenType = typename GetEigenType<OutDtype>::type;
Kevin Cheng1533b852021-09-01 12:51:58 -0700122 using TIn = Eigen::Tensor<InEigenType, 5>;
123 using TWeight = Eigen::Tensor<WeightEigenType, 5>;
James Ward8b390432022-08-12 20:48:56 +0100124 using TBias = Eigen::Tensor<OutEigenType, 1>;
125 using TOut = Eigen::Tensor<OutEigenType, 5>;
Kevin Cheng1533b852021-09-01 12:51:58 -0700126
James Wardd34b3fc2023-01-18 14:51:25 +0000127 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
128 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Kevin Cheng1533b852021-09-01 12:51:58 -0700129
130protected:
131 TosaReference::TensorTemplate<TIn>* input;
132 TosaReference::TensorTemplate<TWeight>* weight;
133 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100134 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng1533b852021-09-01 12:51:58 -0700135 tosa::TosaConvAttribute* attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700136};
137
Tai Lyf36f2562024-03-14 16:21:29 +0000138template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700139class OpDepthwiseConv2d : public GraphNode
140{
141public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000142 OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 virtual ~OpDepthwiseConv2d();
144
145 virtual int checkTensorAttributes() final;
146 virtual int eval() final;
147
Eric Kunzee5e26762020-10-13 16:11:07 -0700148 using InEigenType = typename GetEigenType<InDtype>::type;
149 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
Tai Lyf36f2562024-03-14 16:21:29 +0000150 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
James Wardd34b3fc2023-01-18 14:51:25 +0000151 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700152 using TIn = Eigen::Tensor<InEigenType, 4>;
153 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +0100154 using TBias = Eigen::Tensor<OutEigenType, 1>;
155 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700156
James Wardd34b3fc2023-01-18 14:51:25 +0000157 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
158 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700159
160protected:
161 TosaReference::TensorTemplate<TIn>* input;
162 TosaReference::TensorTemplate<TWeight>* weight;
163 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100164 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700165 tosa::TosaConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700166};
167
Tai Lya4d748b2023-03-28 22:06:56 +0000168template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700169class OpFullyConnected : public GraphNode
170{
171public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000172 OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 virtual ~OpFullyConnected();
174
175 virtual int checkTensorAttributes() final;
176 virtual int eval() final;
177
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000178 using InEigenType = typename GetEigenType<InDtype>::type;
179 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
180 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
181 using OutEigenType = typename GetEigenType<OutDtype>::type;
182 using TIn = Eigen::Tensor<InEigenType, 2>;
183 using TWeight = Eigen::Tensor<WeightEigenType, 2>;
184 using TBias = Eigen::Tensor<OutEigenType, 1>;
185 using TOut = Eigen::Tensor<OutEigenType, 2>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700186
James Wardd34b3fc2023-01-18 14:51:25 +0000187 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
188 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700189
190protected:
191 TosaReference::TensorTemplate<TIn>* input;
192 TosaReference::TensorTemplate<TWeight>* weight;
193 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100194 TosaReference::TensorTemplate<TOut>* output;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000195
196 tosa::TosaFullyConnectedAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700197};
198
Tai Lya4d748b2023-03-28 22:06:56 +0000199template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700200class OpMatMul : public GraphNode
201{
202public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000203 OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700204 virtual ~OpMatMul();
205
206 virtual int checkTensorAttributes() final;
207 virtual int eval() final;
208
Eric Kunzee5e26762020-10-13 16:11:07 -0700209 using InEigenType = typename GetEigenType<Dtype>::type;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000210 using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
James Wardd34b3fc2023-01-18 14:51:25 +0000211 using OutEigenType = typename GetEigenType<OutDtype>::type;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700212 using TIn = Eigen::Tensor<InEigenType, 3>;
James Ward8b390432022-08-12 20:48:56 +0100213 using TOut = Eigen::Tensor<OutEigenType, 3>;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700214 using TInRank2 = Eigen::Tensor<InEigenType, 2>;
215 using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
James Wardd34b3fc2023-01-18 14:51:25 +0000216 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
217 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700218
219protected:
220 TosaReference::TensorTemplate<TIn>* a;
221 TosaReference::TensorTemplate<TIn>* b;
James Ward8b390432022-08-12 20:48:56 +0100222 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng2d60f002021-06-09 14:18:32 -0700223 int64_t N;
224 int64_t H;
225 int64_t W;
226 int64_t C;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000227
228 tosa::TosaMatMulAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700229};
230
Tai Lya4d748b2023-03-28 22:06:56 +0000231template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700232class OpMaxPool2d : public GraphNode
233{
234public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000235 OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700236 virtual ~OpMaxPool2d();
237
238 virtual int checkTensorAttributes();
239 virtual int eval();
240
241 using InEigenType = typename GetEigenType<Dtype>::type;
242 using OutEigenType = typename GetEigenType<Dtype>::type;
243 using TIn = Eigen::Tensor<InEigenType, 4>;
244 using TOut = Eigen::Tensor<OutEigenType, 4>;
245
246protected:
247 TosaReference::TensorTemplate<TIn>* in;
248 TosaReference::TensorTemplate<TOut>* out;
Kevin Cheng93a16282021-08-31 16:14:03 -0700249 tosa::TosaPoolAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700250};
251
Tai Lya4d748b2023-03-28 22:06:56 +0000252template <TOSA_REF_TYPE Dtype>
Luke Hutton57287132023-02-06 14:54:18 +0000253class OpFFT2d : public GraphNode
254{
255public:
256 OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
257 virtual ~OpFFT2d();
258
259 virtual int checkTensorAttributes() final;
260 virtual int eval() final;
261
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000262 using InEigenType = typename GetEigenType<Dtype>::type;
263 using OutEigenType = typename GetEigenType<Dtype>::type;
264 using TIn = Eigen::Tensor<InEigenType, 3>;
265 using TOut = Eigen::Tensor<OutEigenType, 3>;
Luke Hutton57287132023-02-06 14:54:18 +0000266
267protected:
268 TosaReference::TensorTemplate<TIn>* in_real;
269 TosaReference::TensorTemplate<TIn>* in_imag;
270 TosaReference::TensorTemplate<TOut>* out_real;
271 TosaReference::TensorTemplate<TOut>* out_imag;
272 tosa::TosaFFTAttribute* attribute;
273};
274
Tai Lya4d748b2023-03-28 22:06:56 +0000275template <TOSA_REF_TYPE Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +0000276class OpRFFT2d : public GraphNode
277{
278public:
279 OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
280 virtual ~OpRFFT2d();
281
282 virtual int checkTensorAttributes() final;
283 virtual int eval() final;
284
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000285 using InEigenType = typename GetEigenType<Dtype>::type;
286 using OutEigenType = typename GetEigenType<Dtype>::type;
287 using TIn = Eigen::Tensor<InEigenType, 3>;
288 using TOut = Eigen::Tensor<OutEigenType, 3>;
Luke Hutton261b7b62023-01-10 14:50:31 +0000289
290protected:
291 TosaReference::TensorTemplate<TIn>* in;
292 TosaReference::TensorTemplate<TOut>* out_real;
293 TosaReference::TensorTemplate<TOut>* out_imag;
Tai Lyfd8fde82023-11-13 20:18:14 +0000294 tosa::TosaRFFTAttribute* attribute;
Luke Hutton261b7b62023-01-10 14:50:31 +0000295};
296
Tai Lyf36f2562024-03-14 16:21:29 +0000297template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700298class OpTransposeConv2d : public GraphNode
299{
300public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000301 OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700302 virtual ~OpTransposeConv2d();
303
304 virtual int checkTensorAttributes() final;
305 virtual int eval() final;
306
Eric Kunzee5e26762020-10-13 16:11:07 -0700307 using InEigenType = typename GetEigenType<InDtype>::type;
308 using WeightEigenType = typename GetEigenType<WeightDtype>::type;
Tai Lyf36f2562024-03-14 16:21:29 +0000309 using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
James Wardd34b3fc2023-01-18 14:51:25 +0000310 using OutEigenType = typename GetEigenType<OutDtype>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700311 using TIn = Eigen::Tensor<InEigenType, 4>;
312 using TWeight = Eigen::Tensor<WeightEigenType, 4>;
James Ward8b390432022-08-12 20:48:56 +0100313 using TBias = Eigen::Tensor<OutEigenType, 1>;
314 using TOut = Eigen::Tensor<OutEigenType, 4>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700315
James Wardd34b3fc2023-01-18 14:51:25 +0000316 static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
317 static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700318
319protected:
320 TosaReference::TensorTemplate<TIn>* input;
321 TosaReference::TensorTemplate<TWeight>* weight;
322 TosaReference::TensorTemplate<TBias>* bias;
James Ward8b390432022-08-12 20:48:56 +0100323 TosaReference::TensorTemplate<TOut>* output;
Kevin Cheng93a16282021-08-31 16:14:03 -0700324 TosaTransposeConvAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700325};
326
327}; // namespace TosaReference
328
329#endif