blob: dee2ae0d1156f973907e7c97a40fb25b7f26feec [file] [log] [blame]

// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef OPS_DATA_LAYOUT_H
17#define OPS_DATA_LAYOUT_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
Tai Lya4d748b2023-03-28 22:06:56 +000026template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070027class OpConcat : public GraphNode
28{
29public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~OpConcat();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
36 using InEigenType = typename GetEigenType<Dtype>::type;
37 using OutEigenType = typename GetEigenType<Dtype>::type;
38 using TIn = Eigen::Tensor<InEigenType, Rank>;
39 using TOut = Eigen::Tensor<OutEigenType, Rank>;
40
41protected:
42 Eigen::array<int, Rank> reverser;
Kevin Chengad15dfa2021-03-04 15:15:03 -080043 std::vector<TosaReference::TensorTemplate<TIn>*> ins;
Eric Kunzee5e26762020-10-13 16:11:07 -070044 TosaAxisAttribute* attribute;
45 TosaReference::TensorTemplate<TOut>* out;
46};
47
Tai Lya4d748b2023-03-28 22:06:56 +000048template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070049class OpPad : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpPad();
54 virtual int checkTensorAttributes();
55 virtual int eval();
56
Tai Lye095da72024-01-25 22:00:18 +000057 using InEigenType = typename GetEigenType<Dtype>::type;
58 using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
59 using OutEigenType = typename GetEigenType<Dtype>::type;
60 using TIn = Eigen::Tensor<InEigenType, Rank>;
61 using TPadding = Eigen::Tensor<InEigenShapeType, 1>;
62 using TOut = Eigen::Tensor<OutEigenType, Rank>;
Eric Kunzee5e26762020-10-13 16:11:07 -070063
64protected:
65 Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
66 TosaReference::TensorTemplate<TIn>* in;
Tai Lye095da72024-01-25 22:00:18 +000067 TosaReference::TensorTemplate<TPadding>* padding;
Eric Kunzee5e26762020-10-13 16:11:07 -070068 TosaReference::TensorTemplate<TOut>* out;
Kevin Chengfe392ce2021-10-18 21:51:55 +000069 TosaPadAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070070};
71
Won Jeona21b2e82023-08-10 10:33:01 +000072template <int Rank, TOSA_REF_TYPE Dtype>
73class OpDim : public GraphNode
74{
75public:
76 OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
77 virtual ~OpDim();
78
79 virtual int checkTensorAttributes();
80 virtual int eval();
81
82 using InEigenType = typename GetEigenType<Dtype>::type;
83 using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
84 using TIn = Eigen::Tensor<InEigenType, Rank>;
Tai Ly8690a082023-12-18 20:40:24 +000085 using TOut = Eigen::Tensor<OutEigenType, 1>;
Won Jeona21b2e82023-08-10 10:33:01 +000086
87protected:
88 TosaReference::TensorTemplate<TIn>* in;
89 TosaReference::TensorTemplate<TOut>* out;
90 TosaAxisAttribute* attribute;
91};
92
Tai Lya4d748b2023-03-28 22:06:56 +000093template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070094class OpReshape : public GraphNode
95{
96public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000097 OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070098 virtual ~OpReshape();
99
100 virtual int checkTensorAttributes();
101 virtual int eval();
102
103 using InEigenType = typename GetEigenType<Dtype>::type;
104 using OutEigenType = typename GetEigenType<Dtype>::type;
105 using TIn = Eigen::Tensor<InEigenType, InRank>;
106 using TOut = Eigen::Tensor<OutEigenType, OutRank>;
107
108protected:
109 Eigen::array<Eigen::Index, OutRank> array_shape;
110 Eigen::array<Eigen::Index, InRank> in_reverser;
111 Eigen::array<Eigen::Index, OutRank> out_reverser;
112 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700113 TosaReference::TensorTemplate<TOut>* out;
114};
115
Tai Lya4d748b2023-03-28 22:06:56 +0000116template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700117class OpReverse : public GraphNode
118{
119public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000120 OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700121 virtual ~OpReverse();
122
123 virtual int checkTensorAttributes();
124 virtual int eval();
125
126 using InEigenType = typename GetEigenType<Dtype>::type;
127 using OutEigenType = typename GetEigenType<Dtype>::type;
128 using TIn = Eigen::Tensor<InEigenType, Rank>;
129 using TOut = Eigen::Tensor<OutEigenType, Rank>;
130
131protected:
132 TosaAxisAttribute* attribute;
133 TosaReference::TensorTemplate<TIn>* in;
134 TosaReference::TensorTemplate<TOut>* out;
135 Eigen::array<bool, Rank> reverse_array;
136};
137
Tai Lya4d748b2023-03-28 22:06:56 +0000138template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700139class OpSlice : public GraphNode
140{
141public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000142 OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 virtual ~OpSlice();
144
145 virtual int checkTensorAttributes();
146 virtual int eval();
147
TatWai Chong01f937a2024-01-24 22:57:07 -0800148 using InEigenType = typename GetEigenType<Dtype>::type;
149 using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
150 using OutEigenType = typename GetEigenType<Dtype>::type;
151 using TIn = Eigen::Tensor<InEigenType, Rank>;
TatWai Chong97f1c0e2024-02-05 11:56:46 -0800152 using TSlicing = Eigen::Tensor<InEigenShapeType, 1>;
TatWai Chong01f937a2024-01-24 22:57:07 -0800153 using TOut = Eigen::Tensor<OutEigenType, Rank>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700154
155protected:
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 Eigen::array<Eigen::Index, Rank> begin_array;
157 Eigen::array<Eigen::Index, Rank> size_array;
TatWai Chong97f1c0e2024-02-05 11:56:46 -0800158 TosaReference::TensorTemplate<TSlicing>* start;
159 TosaReference::TensorTemplate<TSlicing>* size;
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 TosaReference::TensorTemplate<TIn>* in;
161 TosaReference::TensorTemplate<TOut>* out;
162};
163
Tai Lya4d748b2023-03-28 22:06:56 +0000164template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700165class OpTileBase : public GraphNode
166{
167public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000168 OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700169 virtual ~OpTileBase();
170
171 virtual int checkTensorAttributes();
172
Tai Ly8690a082023-12-18 20:40:24 +0000173 using InEigenType = typename GetEigenType<Dtype>::type;
174 using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
175 using OutEigenType = typename GetEigenType<Dtype>::type;
176 using TIn = Eigen::Tensor<InEigenType, Rank>;
177 using TInMultiples = Eigen::Tensor<InEigenShapeType, 1>;
178 using TOut = Eigen::Tensor<OutEigenType, Rank>;
Eric Kunzee5e26762020-10-13 16:11:07 -0700179
180protected:
181 TosaTileAttribute* attribute;
182 TosaReference::TensorTemplate<TIn>* in;
Tai Ly8690a082023-12-18 20:40:24 +0000183 TosaReference::TensorTemplate<TInMultiples>* multiples;
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 TosaReference::TensorTemplate<TOut>* out;
185};
186
187// primary template for op tile
Tai Lya4d748b2023-03-28 22:06:56 +0000188template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700189class OpTile : public OpTileBase<Rank, Dtype>
190{
191public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000192 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
193 : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700194 {}
195
196protected:
197 virtual int eval();
198};
199
200// partial specialization for specific rank
201#define DEF_OP_TILE_RANK(N) \
Tai Lya4d748b2023-03-28 22:06:56 +0000202 template <TOSA_REF_TYPE Dtype> \
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
204 { \
205 public: \
Tai Lya4d748b2023-03-28 22:06:56 +0000206 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
207 : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 {} \
209 \
210 protected: \
211 virtual int eval(); \
212 };
213
214DEF_OP_TILE_RANK(1)
215DEF_OP_TILE_RANK(2)
216DEF_OP_TILE_RANK(3)
217DEF_OP_TILE_RANK(4)
Luke Huttona4e48ca2023-02-22 11:53:48 +0000218DEF_OP_TILE_RANK(5)
219DEF_OP_TILE_RANK(6)
Eric Kunzee5e26762020-10-13 16:11:07 -0700220
221#undef DEF_OP_TILE_RANK
222
Tai Lya4d748b2023-03-28 22:06:56 +0000223template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700224class OpTranspose : public GraphNode
225{
226public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000227 OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700228 virtual ~OpTranspose();
229
230 virtual int checkTensorAttributes();
231 virtual int eval();
232
233 using InEigenType = typename GetEigenType<Dtype>::type;
234 using OutEigenType = typename GetEigenType<Dtype>::type;
235 using TIn = Eigen::Tensor<InEigenType, Rank>;
236 using TOut = Eigen::Tensor<OutEigenType, Rank>;
237
238protected:
239 Eigen::array<int, Rank> perm_array;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000240 TosaTransposeAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700242 TosaReference::TensorTemplate<TOut>* out;
243};
244}; // namespace TosaReference
245
246#endif