blob: 9341709c1bee9d98eecc86b48de9e8330416f660 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Ly8690a082023-12-18 20:40:24 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OPS_DATA_LAYOUT_H
17#define OPS_DATA_LAYOUT_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
Tai Lya4d748b2023-03-28 22:06:56 +000026template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070027class OpConcat : public GraphNode
28{
29public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~OpConcat();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
36 using InEigenType = typename GetEigenType<Dtype>::type;
37 using OutEigenType = typename GetEigenType<Dtype>::type;
38 using TIn = Eigen::Tensor<InEigenType, Rank>;
39 using TOut = Eigen::Tensor<OutEigenType, Rank>;
40
41protected:
42 Eigen::array<int, Rank> reverser;
Kevin Chengad15dfa2021-03-04 15:15:03 -080043 std::vector<TosaReference::TensorTemplate<TIn>*> ins;
Eric Kunzee5e26762020-10-13 16:11:07 -070044 TosaAxisAttribute* attribute;
45 TosaReference::TensorTemplate<TOut>* out;
46};
47
Tai Lya4d748b2023-03-28 22:06:56 +000048template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070049class OpPad : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpPad();
54 virtual int checkTensorAttributes();
55 virtual int eval();
56
57 using InEigenType = typename GetEigenType<Dtype>::type;
58 using OutEigenType = typename GetEigenType<Dtype>::type;
59 using TIn = Eigen::Tensor<InEigenType, Rank>;
60 using TOut = Eigen::Tensor<OutEigenType, Rank>;
61
62protected:
63 Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
64 TosaReference::TensorTemplate<TIn>* in;
65 TosaReference::TensorTemplate<TOut>* out;
Kevin Chengfe392ce2021-10-18 21:51:55 +000066 TosaPadAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070067};
68
Won Jeona21b2e82023-08-10 10:33:01 +000069template <int Rank, TOSA_REF_TYPE Dtype>
70class OpDim : public GraphNode
71{
72public:
73 OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
74 virtual ~OpDim();
75
76 virtual int checkTensorAttributes();
77 virtual int eval();
78
79 using InEigenType = typename GetEigenType<Dtype>::type;
80 using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
81 using TIn = Eigen::Tensor<InEigenType, Rank>;
Tai Ly8690a082023-12-18 20:40:24 +000082 using TOut = Eigen::Tensor<OutEigenType, 1>;
Won Jeona21b2e82023-08-10 10:33:01 +000083
84protected:
85 TosaReference::TensorTemplate<TIn>* in;
86 TosaReference::TensorTemplate<TOut>* out;
87 TosaAxisAttribute* attribute;
88};
89
Tai Lya4d748b2023-03-28 22:06:56 +000090template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070091class OpReshape : public GraphNode
92{
93public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000094 OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070095 virtual ~OpReshape();
96
97 virtual int checkTensorAttributes();
98 virtual int eval();
99
100 using InEigenType = typename GetEigenType<Dtype>::type;
101 using OutEigenType = typename GetEigenType<Dtype>::type;
102 using TIn = Eigen::Tensor<InEigenType, InRank>;
103 using TOut = Eigen::Tensor<OutEigenType, OutRank>;
104
105protected:
106 Eigen::array<Eigen::Index, OutRank> array_shape;
107 Eigen::array<Eigen::Index, InRank> in_reverser;
108 Eigen::array<Eigen::Index, OutRank> out_reverser;
109 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700110 TosaReference::TensorTemplate<TOut>* out;
111};
112
Tai Lya4d748b2023-03-28 22:06:56 +0000113template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700114class OpReverse : public GraphNode
115{
116public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000117 OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700118 virtual ~OpReverse();
119
120 virtual int checkTensorAttributes();
121 virtual int eval();
122
123 using InEigenType = typename GetEigenType<Dtype>::type;
124 using OutEigenType = typename GetEigenType<Dtype>::type;
125 using TIn = Eigen::Tensor<InEigenType, Rank>;
126 using TOut = Eigen::Tensor<OutEigenType, Rank>;
127
128protected:
129 TosaAxisAttribute* attribute;
130 TosaReference::TensorTemplate<TIn>* in;
131 TosaReference::TensorTemplate<TOut>* out;
132 Eigen::array<bool, Rank> reverse_array;
133};
134
Tai Lya4d748b2023-03-28 22:06:56 +0000135template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700136class OpSlice : public GraphNode
137{
138public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000139 OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700140 virtual ~OpSlice();
141
142 virtual int checkTensorAttributes();
143 virtual int eval();
144
145 using InEigenType = typename GetEigenType<Dtype>::type;
146 using OutEigenType = typename GetEigenType<Dtype>::type;
147 using TIn = Eigen::Tensor<InEigenType, Rank>;
148 using TOut = Eigen::Tensor<OutEigenType, Rank>;
149
150protected:
151 TosaSliceAttribute* attribute;
152 Eigen::array<Eigen::Index, Rank> begin_array;
153 Eigen::array<Eigen::Index, Rank> size_array;
154 TosaReference::TensorTemplate<TIn>* in;
155 TosaReference::TensorTemplate<TOut>* out;
156};
157
Tai Lya4d748b2023-03-28 22:06:56 +0000158template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700159class OpTileBase : public GraphNode
160{
161public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000162 OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700163 virtual ~OpTileBase();
164
165 virtual int checkTensorAttributes();
166
Tai Ly8690a082023-12-18 20:40:24 +0000167 using InEigenType = typename GetEigenType<Dtype>::type;
168 using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
169 using OutEigenType = typename GetEigenType<Dtype>::type;
170 using TIn = Eigen::Tensor<InEigenType, Rank>;
171 using TInMultiples = Eigen::Tensor<InEigenShapeType, 1>;
172 using TOut = Eigen::Tensor<OutEigenType, Rank>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700173
174protected:
175 TosaTileAttribute* attribute;
176 TosaReference::TensorTemplate<TIn>* in;
Tai Ly8690a082023-12-18 20:40:24 +0000177 TosaReference::TensorTemplate<TInMultiples>* multiples;
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 TosaReference::TensorTemplate<TOut>* out;
179};
180
181// primary template for op tile
Tai Lya4d748b2023-03-28 22:06:56 +0000182template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700183class OpTile : public OpTileBase<Rank, Dtype>
184{
185public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000186 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
187 : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 {}
189
190protected:
191 virtual int eval();
192};
193
194// partial specialization for specific rank
195#define DEF_OP_TILE_RANK(N) \
Tai Lya4d748b2023-03-28 22:06:56 +0000196 template <TOSA_REF_TYPE Dtype> \
Eric Kunzee5e26762020-10-13 16:11:07 -0700197 class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
198 { \
199 public: \
Tai Lya4d748b2023-03-28 22:06:56 +0000200 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
201 : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700202 {} \
203 \
204 protected: \
205 virtual int eval(); \
206 };
207
208DEF_OP_TILE_RANK(1)
209DEF_OP_TILE_RANK(2)
210DEF_OP_TILE_RANK(3)
211DEF_OP_TILE_RANK(4)
Luke Huttona4e48ca2023-02-22 11:53:48 +0000212DEF_OP_TILE_RANK(5)
213DEF_OP_TILE_RANK(6)
Eric Kunzee5e26762020-10-13 16:11:07 -0700214
215#undef DEF_OP_TILE_RANK
216
Tai Lya4d748b2023-03-28 22:06:56 +0000217template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700218class OpTranspose : public GraphNode
219{
220public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000221 OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 virtual ~OpTranspose();
223
224 virtual int checkTensorAttributes();
225 virtual int eval();
226
227 using InEigenType = typename GetEigenType<Dtype>::type;
228 using OutEigenType = typename GetEigenType<Dtype>::type;
229 using TIn = Eigen::Tensor<InEigenType, Rank>;
230 using TOut = Eigen::Tensor<OutEigenType, Rank>;
231
232protected:
233 Eigen::array<int, Rank> perm_array;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000234 TosaTransposeAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700235 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700236 TosaReference::TensorTemplate<TOut>* out;
237};
238}; // namespace TosaReference
239
240#endif