blob: 024f9a2b203298d8d7c9b5ef53d85c0a278d3652 [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef OPS_DATA_LAYOUT_H
17#define OPS_DATA_LAYOUT_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
Tai Lya4d748b2023-03-28 22:06:56 +000026template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070027class OpConcat : public GraphNode
28{
29public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~OpConcat();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
36 using InEigenType = typename GetEigenType<Dtype>::type;
37 using OutEigenType = typename GetEigenType<Dtype>::type;
38 using TIn = Eigen::Tensor<InEigenType, Rank>;
39 using TOut = Eigen::Tensor<OutEigenType, Rank>;
40
41protected:
42 Eigen::array<int, Rank> reverser;
Kevin Chengad15dfa2021-03-04 15:15:03 -080043 std::vector<TosaReference::TensorTemplate<TIn>*> ins;
Eric Kunzee5e26762020-10-13 16:11:07 -070044 TosaAxisAttribute* attribute;
45 TosaReference::TensorTemplate<TOut>* out;
46};
47
Tai Lya4d748b2023-03-28 22:06:56 +000048template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070049class OpPad : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpPad();
54 virtual int checkTensorAttributes();
55 virtual int eval();
56
57 using InEigenType = typename GetEigenType<Dtype>::type;
58 using OutEigenType = typename GetEigenType<Dtype>::type;
59 using TIn = Eigen::Tensor<InEigenType, Rank>;
60 using TOut = Eigen::Tensor<OutEigenType, Rank>;
61
62protected:
63 Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
64 TosaReference::TensorTemplate<TIn>* in;
65 TosaReference::TensorTemplate<TOut>* out;
Kevin Chengfe392ce2021-10-18 21:51:55 +000066 TosaPadAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070067};
68
Won Jeona21b2e82023-08-10 10:33:01 +000069template <int Rank, TOSA_REF_TYPE Dtype>
70class OpDim : public GraphNode
71{
72public:
73 OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
74 virtual ~OpDim();
75
76 virtual int checkTensorAttributes();
77 virtual int eval();
78
79 using InEigenType = typename GetEigenType<Dtype>::type;
80 using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
81 using TIn = Eigen::Tensor<InEigenType, Rank>;
82 using TOut = Eigen::Tensor<OutEigenType, 0>;
83
84protected:
85 TosaReference::TensorTemplate<TIn>* in;
86 TosaReference::TensorTemplate<TOut>* out;
87 TosaAxisAttribute* attribute;
88};
89
Tai Lya4d748b2023-03-28 22:06:56 +000090template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070091class OpReshape : public GraphNode
92{
93public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000094 OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070095 virtual ~OpReshape();
96
97 virtual int checkTensorAttributes();
98 virtual int eval();
99
100 using InEigenType = typename GetEigenType<Dtype>::type;
101 using OutEigenType = typename GetEigenType<Dtype>::type;
102 using TIn = Eigen::Tensor<InEigenType, InRank>;
103 using TOut = Eigen::Tensor<OutEigenType, OutRank>;
104
105protected:
106 Eigen::array<Eigen::Index, OutRank> array_shape;
107 Eigen::array<Eigen::Index, InRank> in_reverser;
108 Eigen::array<Eigen::Index, OutRank> out_reverser;
109 TosaReference::TensorTemplate<TIn>* in;
110 TosaReshapeAttribute* attribute;
111 TosaReference::TensorTemplate<TOut>* out;
112};
113
Tai Lya4d748b2023-03-28 22:06:56 +0000114template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700115class OpReverse : public GraphNode
116{
117public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000118 OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700119 virtual ~OpReverse();
120
121 virtual int checkTensorAttributes();
122 virtual int eval();
123
124 using InEigenType = typename GetEigenType<Dtype>::type;
125 using OutEigenType = typename GetEigenType<Dtype>::type;
126 using TIn = Eigen::Tensor<InEigenType, Rank>;
127 using TOut = Eigen::Tensor<OutEigenType, Rank>;
128
129protected:
130 TosaAxisAttribute* attribute;
131 TosaReference::TensorTemplate<TIn>* in;
132 TosaReference::TensorTemplate<TOut>* out;
133 Eigen::array<bool, Rank> reverse_array;
134};
135
Tai Lya4d748b2023-03-28 22:06:56 +0000136template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700137class OpSlice : public GraphNode
138{
139public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000140 OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 virtual ~OpSlice();
142
143 virtual int checkTensorAttributes();
144 virtual int eval();
145
146 using InEigenType = typename GetEigenType<Dtype>::type;
147 using OutEigenType = typename GetEigenType<Dtype>::type;
148 using TIn = Eigen::Tensor<InEigenType, Rank>;
149 using TOut = Eigen::Tensor<OutEigenType, Rank>;
150
151protected:
152 TosaSliceAttribute* attribute;
153 Eigen::array<Eigen::Index, Rank> begin_array;
154 Eigen::array<Eigen::Index, Rank> size_array;
155 TosaReference::TensorTemplate<TIn>* in;
156 TosaReference::TensorTemplate<TOut>* out;
157};
158
Tai Lya4d748b2023-03-28 22:06:56 +0000159template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700160class OpTileBase : public GraphNode
161{
162public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000163 OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 virtual ~OpTileBase();
165
166 virtual int checkTensorAttributes();
167
168 using InEigenType = typename GetEigenType<Dtype>::type;
169 using OutEigenType = typename GetEigenType<Dtype>::type;
170 using TIn = Eigen::Tensor<InEigenType, Rank>;
171 using TOut = Eigen::Tensor<OutEigenType, Rank>;
172
173protected:
174 TosaTileAttribute* attribute;
175 TosaReference::TensorTemplate<TIn>* in;
176 TosaReference::TensorTemplate<TOut>* out;
177};
178
179// primary template for op tile
Tai Lya4d748b2023-03-28 22:06:56 +0000180template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700181class OpTile : public OpTileBase<Rank, Dtype>
182{
183public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000184 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
185 : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 {}
187
188protected:
189 virtual int eval();
190};
191
192// partial specialization for specific rank
193#define DEF_OP_TILE_RANK(N) \
Tai Lya4d748b2023-03-28 22:06:56 +0000194 template <TOSA_REF_TYPE Dtype> \
Eric Kunzee5e26762020-10-13 16:11:07 -0700195 class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
196 { \
197 public: \
Tai Lya4d748b2023-03-28 22:06:56 +0000198 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
199 : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700200 {} \
201 \
202 protected: \
203 virtual int eval(); \
204 };
205
206DEF_OP_TILE_RANK(1)
207DEF_OP_TILE_RANK(2)
208DEF_OP_TILE_RANK(3)
209DEF_OP_TILE_RANK(4)
Luke Huttona4e48ca2023-02-22 11:53:48 +0000210DEF_OP_TILE_RANK(5)
211DEF_OP_TILE_RANK(6)
Eric Kunzee5e26762020-10-13 16:11:07 -0700212
213#undef DEF_OP_TILE_RANK
214
Tai Lya4d748b2023-03-28 22:06:56 +0000215template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700216class OpTranspose : public GraphNode
217{
218public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000219 OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700220 virtual ~OpTranspose();
221
222 virtual int checkTensorAttributes();
223 virtual int eval();
224
225 using InEigenType = typename GetEigenType<Dtype>::type;
226 using OutEigenType = typename GetEigenType<Dtype>::type;
227 using TIn = Eigen::Tensor<InEigenType, Rank>;
228 using TOut = Eigen::Tensor<OutEigenType, Rank>;
229
230protected:
231 Eigen::array<int, Rank> perm_array;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000232 TosaTransposeAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700233 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700234 TosaReference::TensorTemplate<TOut>* out;
235};
236}; // namespace TosaReference
237
238#endif