blob: c6513ae4bd35224dc4227a8110cbff94ad281f8f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_DATA_LAYOUT_H
17#define OPS_DATA_LAYOUT_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
26template <int Rank, DType Dtype>
27class OpConcat : public GraphNode
28{
29public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~OpConcat();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
36 using InEigenType = typename GetEigenType<Dtype>::type;
37 using OutEigenType = typename GetEigenType<Dtype>::type;
38 using TIn = Eigen::Tensor<InEigenType, Rank>;
39 using TOut = Eigen::Tensor<OutEigenType, Rank>;
40
41protected:
42 Eigen::array<int, Rank> reverser;
Kevin Chengad15dfa2021-03-04 15:15:03 -080043 std::vector<TosaReference::TensorTemplate<TIn>*> ins;
Eric Kunzee5e26762020-10-13 16:11:07 -070044 TosaAxisAttribute* attribute;
45 TosaReference::TensorTemplate<TOut>* out;
46};
47
48template <int Rank, DType Dtype>
49class OpPad : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070053 virtual ~OpPad();
54 virtual int checkTensorAttributes();
55 virtual int eval();
56
57 using InEigenType = typename GetEigenType<Dtype>::type;
58 using OutEigenType = typename GetEigenType<Dtype>::type;
59 using TIn = Eigen::Tensor<InEigenType, Rank>;
60 using TOut = Eigen::Tensor<OutEigenType, Rank>;
61
62protected:
63 Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
64 TosaReference::TensorTemplate<TIn>* in;
65 TosaReference::TensorTemplate<TOut>* out;
Kevin Chengfe392ce2021-10-18 21:51:55 +000066 TosaPadAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -070067};
68
69template <int InRank, int OutRank, DType Dtype>
70class OpReshape : public GraphNode
71{
72public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000073 OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070074 virtual ~OpReshape();
75
76 virtual int checkTensorAttributes();
77 virtual int eval();
78
79 using InEigenType = typename GetEigenType<Dtype>::type;
80 using OutEigenType = typename GetEigenType<Dtype>::type;
81 using TIn = Eigen::Tensor<InEigenType, InRank>;
82 using TOut = Eigen::Tensor<OutEigenType, OutRank>;
83
84protected:
85 Eigen::array<Eigen::Index, OutRank> array_shape;
86 Eigen::array<Eigen::Index, InRank> in_reverser;
87 Eigen::array<Eigen::Index, OutRank> out_reverser;
88 TosaReference::TensorTemplate<TIn>* in;
89 TosaReshapeAttribute* attribute;
90 TosaReference::TensorTemplate<TOut>* out;
91};
92
93template <int Rank, DType Dtype>
94class OpReverse : public GraphNode
95{
96public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000097 OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070098 virtual ~OpReverse();
99
100 virtual int checkTensorAttributes();
101 virtual int eval();
102
103 using InEigenType = typename GetEigenType<Dtype>::type;
104 using OutEigenType = typename GetEigenType<Dtype>::type;
105 using TIn = Eigen::Tensor<InEigenType, Rank>;
106 using TOut = Eigen::Tensor<OutEigenType, Rank>;
107
108protected:
109 TosaAxisAttribute* attribute;
110 TosaReference::TensorTemplate<TIn>* in;
111 TosaReference::TensorTemplate<TOut>* out;
112 Eigen::array<bool, Rank> reverse_array;
113};
114
115template <int Rank, DType Dtype>
116class OpSlice : public GraphNode
117{
118public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000119 OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700120 virtual ~OpSlice();
121
122 virtual int checkTensorAttributes();
123 virtual int eval();
124
125 using InEigenType = typename GetEigenType<Dtype>::type;
126 using OutEigenType = typename GetEigenType<Dtype>::type;
127 using TIn = Eigen::Tensor<InEigenType, Rank>;
128 using TOut = Eigen::Tensor<OutEigenType, Rank>;
129
130protected:
131 TosaSliceAttribute* attribute;
132 Eigen::array<Eigen::Index, Rank> begin_array;
133 Eigen::array<Eigen::Index, Rank> size_array;
134 TosaReference::TensorTemplate<TIn>* in;
135 TosaReference::TensorTemplate<TOut>* out;
136};
137
138template <int Rank, DType Dtype>
139class OpTileBase : public GraphNode
140{
141public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000142 OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 virtual ~OpTileBase();
144
145 virtual int checkTensorAttributes();
146
147 using InEigenType = typename GetEigenType<Dtype>::type;
148 using OutEigenType = typename GetEigenType<Dtype>::type;
149 using TIn = Eigen::Tensor<InEigenType, Rank>;
150 using TOut = Eigen::Tensor<OutEigenType, Rank>;
151
152protected:
153 TosaTileAttribute* attribute;
154 TosaReference::TensorTemplate<TIn>* in;
155 TosaReference::TensorTemplate<TOut>* out;
156};
157
158// primary template for op tile
159template <int Rank, DType Dtype>
160class OpTile : public OpTileBase<Rank, Dtype>
161{
162public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000163 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
164 : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700165 {}
166
167protected:
168 virtual int eval();
169};
170
171// partial specialization for specific rank
172#define DEF_OP_TILE_RANK(N) \
173 template <DType Dtype> \
174 class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
175 { \
176 public: \
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000177 OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
178 : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 {} \
180 \
181 protected: \
182 virtual int eval(); \
183 };
184
185DEF_OP_TILE_RANK(1)
186DEF_OP_TILE_RANK(2)
187DEF_OP_TILE_RANK(3)
188DEF_OP_TILE_RANK(4)
189
190#undef DEF_OP_TILE_RANK
191
192template <int Rank, DType Dtype>
193class OpTranspose : public GraphNode
194{
195public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000196 OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700197 virtual ~OpTranspose();
198
199 virtual int checkTensorAttributes();
200 virtual int eval();
201
202 using InEigenType = typename GetEigenType<Dtype>::type;
203 using OutEigenType = typename GetEigenType<Dtype>::type;
204 using TIn = Eigen::Tensor<InEigenType, Rank>;
205 using TOut = Eigen::Tensor<OutEigenType, Rank>;
206
207protected:
208 Eigen::array<int, Rank> perm_array;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000209 TosaTransposeAttribute* attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700210 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700211 TosaReference::TensorTemplate<TOut>* out;
212};
213}; // namespace TosaReference
214
215#endif