// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OPS_DATA_LAYOUT_H
#define OPS_DATA_LAYOUT_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{

26template <int Rank, DType Dtype>
27class OpConcat : public GraphNode
28{
29public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~OpConcat();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
36 using InEigenType = typename GetEigenType<Dtype>::type;
37 using OutEigenType = typename GetEigenType<Dtype>::type;
38 using TIn = Eigen::Tensor<InEigenType, Rank>;
39 using TOut = Eigen::Tensor<OutEigenType, Rank>;
40
41protected:
42 Eigen::array<int, Rank> reverser;
Kevin Chengad15dfa2021-03-04 15:15:03 -080043 std::vector<TosaReference::TensorTemplate<TIn>*> ins;
Eric Kunzee5e26762020-10-13 16:11:07 -070044 TosaAxisAttribute* attribute;
45 TosaReference::TensorTemplate<TOut>* out;
46};
47
48template <int Rank, DType Dtype>
49class OpPad : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpPad();
54 virtual int checkTensorAttributes();
55 virtual int eval();
56
57 using InEigenType = typename GetEigenType<Dtype>::type;
58 using OutEigenType = typename GetEigenType<Dtype>::type;
59 using TIn = Eigen::Tensor<InEigenType, Rank>;
60 using TOut = Eigen::Tensor<OutEigenType, Rank>;
61
62protected:
63 Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
64 TosaReference::TensorTemplate<TIn>* in;
65 TosaReference::TensorTemplate<TOut>* out;
Kevin Chengfe392ce2021-10-18 21:51:55 +000066 TosaPadAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070067};
68
69template <int InRank, int OutRank, DType Dtype>
70class OpReshape : public GraphNode
71{
72public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000073 OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070074 virtual ~OpReshape();
75
76 virtual int checkTensorAttributes();
77 virtual int eval();
78
79 using InEigenType = typename GetEigenType<Dtype>::type;
80 using OutEigenType = typename GetEigenType<Dtype>::type;
81 using TIn = Eigen::Tensor<InEigenType, InRank>;
82 using TOut = Eigen::Tensor<OutEigenType, OutRank>;
83
84protected:
85 Eigen::array<Eigen::Index, OutRank> array_shape;
86 Eigen::array<Eigen::Index, InRank> in_reverser;
87 Eigen::array<Eigen::Index, OutRank> out_reverser;
88 TosaReference::TensorTemplate<TIn>* in;
89 TosaReshapeAttribute* attribute;
90 TosaReference::TensorTemplate<TOut>* out;
91};
92
93template <int Rank, DType Dtype>
94class OpReverse : public GraphNode
95{
96public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000097 OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070098 virtual ~OpReverse();
99
100 virtual int checkTensorAttributes();
101 virtual int eval();
102
103 using InEigenType = typename GetEigenType<Dtype>::type;
104 using OutEigenType = typename GetEigenType<Dtype>::type;
105 using TIn = Eigen::Tensor<InEigenType, Rank>;
106 using TOut = Eigen::Tensor<OutEigenType, Rank>;
107
108protected:
109 TosaAxisAttribute* attribute;
110 TosaReference::TensorTemplate<TIn>* in;
111 TosaReference::TensorTemplate<TOut>* out;
112 Eigen::array<bool, Rank> reverse_array;
113};
114
115template <int Rank, DType Dtype>
116class OpSlice : public GraphNode
117{
118public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000119 OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700120 virtual ~OpSlice();
121
122 virtual int checkTensorAttributes();
123 virtual int eval();
124
125 using InEigenType = typename GetEigenType<Dtype>::type;
126 using OutEigenType = typename GetEigenType<Dtype>::type;
127 using TIn = Eigen::Tensor<InEigenType, Rank>;
128 using TOut = Eigen::Tensor<OutEigenType, Rank>;
129
130protected:
131 TosaSliceAttribute* attribute;
132 Eigen::array<Eigen::Index, Rank> begin_array;
133 Eigen::array<Eigen::Index, Rank> size_array;
134 TosaReference::TensorTemplate<TIn>* in;
135 TosaReference::TensorTemplate<TOut>* out;
136};
137
138template <int Rank, DType Dtype>
139class OpTileBase : public GraphNode
140{
141public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000142 OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 virtual ~OpTileBase();
144
145 virtual int checkTensorAttributes();
146
147 using InEigenType = typename GetEigenType<Dtype>::type;
148 using OutEigenType = typename GetEigenType<Dtype>::type;
149 using TIn = Eigen::Tensor<InEigenType, Rank>;
150 using TOut = Eigen::Tensor<OutEigenType, Rank>;
151
152protected:
153 TosaTileAttribute* attribute;
154 TosaReference::TensorTemplate<TIn>* in;
155 TosaReference::TensorTemplate<TOut>* out;
156};
157
158// primary template for op tile
159template <int Rank, DType Dtype>
160class OpTile : public OpTileBase<Rank, Dtype>
161{
162public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000163 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
164 : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700165 {}
166
167protected:
168 virtual int eval();
169};
170
171// partial specialization for specific rank
172#define DEF_OP_TILE_RANK(N) \
173 template <DType Dtype> \
174 class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
175 { \
176 public: \
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000177 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
178 : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 {} \
180 \
181 protected: \
182 virtual int eval(); \
183 };
184
185DEF_OP_TILE_RANK(1)
186DEF_OP_TILE_RANK(2)
187DEF_OP_TILE_RANK(3)
188DEF_OP_TILE_RANK(4)
Luke Huttona4e48ca2023-02-22 11:53:48 +0000189DEF_OP_TILE_RANK(5)
190DEF_OP_TILE_RANK(6)
Eric Kunzee5e26762020-10-13 16:11:07 -0700191
192#undef DEF_OP_TILE_RANK
193
194template <int Rank, DType Dtype>
195class OpTranspose : public GraphNode
196{
197public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000198 OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700199 virtual ~OpTranspose();
200
201 virtual int checkTensorAttributes();
202 virtual int eval();
203
204 using InEigenType = typename GetEigenType<Dtype>::type;
205 using OutEigenType = typename GetEigenType<Dtype>::type;
206 using TIn = Eigen::Tensor<InEigenType, Rank>;
207 using TOut = Eigen::Tensor<OutEigenType, Rank>;
208
209protected:
210 Eigen::array<int, Rank> perm_array;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000211 TosaTransposeAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 TosaReference::TensorTemplate<TOut>* out;
214};
};    // namespace TosaReference

#endif