// Source-browser residue (not part of the original file contents):
//   blob 100bd6b6329cd32e8f3552ab13398f582e49dd3e
//   blame: Eric Kunze, commit e5e2676, 2020-10-13 16:11:07 -0700
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef OPS_DATA_LAYOUT_H
17#define OPS_DATA_LAYOUT_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
26template <int Rank, DType Dtype>
27class OpConcat : public GraphNode
28{
29public:
30 OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
31 virtual ~OpConcat();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
36 using InEigenType = typename GetEigenType<Dtype>::type;
37 using OutEigenType = typename GetEigenType<Dtype>::type;
38 using TIn = Eigen::Tensor<InEigenType, Rank>;
39 using TOut = Eigen::Tensor<OutEigenType, Rank>;
40
41protected:
42 Eigen::array<int, Rank> reverser;
43 TosaReference::TensorTemplate<TIn>* lhs;
44 TosaReference::TensorTemplate<TIn>* rhs;
45 TosaAxisAttribute* attribute;
46 TosaReference::TensorTemplate<TOut>* out;
47};
48
49template <int Rank, DType Dtype>
50class OpPad : public GraphNode
51{
52public:
53 OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
54 virtual ~OpPad();
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
58 using InEigenType = typename GetEigenType<Dtype>::type;
59 using OutEigenType = typename GetEigenType<Dtype>::type;
60 using TIn = Eigen::Tensor<InEigenType, Rank>;
61 using TOut = Eigen::Tensor<OutEigenType, Rank>;
62
63protected:
64 Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
65 TosaReference::TensorTemplate<TIn>* in;
66 TosaReference::TensorTemplate<TOut>* out;
67 TosaPadQuantInfo* qinfo;
68};
69
70template <int InRank, int OutRank, DType Dtype>
71class OpReshape : public GraphNode
72{
73public:
74 OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
75 virtual ~OpReshape();
76
77 virtual int checkTensorAttributes();
78 virtual int eval();
79
80 using InEigenType = typename GetEigenType<Dtype>::type;
81 using OutEigenType = typename GetEigenType<Dtype>::type;
82 using TIn = Eigen::Tensor<InEigenType, InRank>;
83 using TOut = Eigen::Tensor<OutEigenType, OutRank>;
84
85protected:
86 Eigen::array<Eigen::Index, OutRank> array_shape;
87 Eigen::array<Eigen::Index, InRank> in_reverser;
88 Eigen::array<Eigen::Index, OutRank> out_reverser;
89 TosaReference::TensorTemplate<TIn>* in;
90 TosaReshapeAttribute* attribute;
91 TosaReference::TensorTemplate<TOut>* out;
92};
93
94template <int Rank, DType Dtype>
95class OpReverse : public GraphNode
96{
97public:
98 OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
99 virtual ~OpReverse();
100
101 virtual int checkTensorAttributes();
102 virtual int eval();
103
104 using InEigenType = typename GetEigenType<Dtype>::type;
105 using OutEigenType = typename GetEigenType<Dtype>::type;
106 using TIn = Eigen::Tensor<InEigenType, Rank>;
107 using TOut = Eigen::Tensor<OutEigenType, Rank>;
108
109protected:
110 TosaAxisAttribute* attribute;
111 TosaReference::TensorTemplate<TIn>* in;
112 TosaReference::TensorTemplate<TOut>* out;
113 Eigen::array<bool, Rank> reverse_array;
114};
115
116template <int Rank, DType Dtype>
117class OpSlice : public GraphNode
118{
119public:
120 OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
121 virtual ~OpSlice();
122
123 virtual int checkTensorAttributes();
124 virtual int eval();
125
126 using InEigenType = typename GetEigenType<Dtype>::type;
127 using OutEigenType = typename GetEigenType<Dtype>::type;
128 using TIn = Eigen::Tensor<InEigenType, Rank>;
129 using TOut = Eigen::Tensor<OutEigenType, Rank>;
130
131protected:
132 TosaSliceAttribute* attribute;
133 Eigen::array<Eigen::Index, Rank> begin_array;
134 Eigen::array<Eigen::Index, Rank> size_array;
135 TosaReference::TensorTemplate<TIn>* in;
136 TosaReference::TensorTemplate<TOut>* out;
137};
138
139template <int Rank, DType Dtype>
140class OpTileBase : public GraphNode
141{
142public:
143 OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
144 virtual ~OpTileBase();
145
146 virtual int checkTensorAttributes();
147
148 using InEigenType = typename GetEigenType<Dtype>::type;
149 using OutEigenType = typename GetEigenType<Dtype>::type;
150 using TIn = Eigen::Tensor<InEigenType, Rank>;
151 using TOut = Eigen::Tensor<OutEigenType, Rank>;
152
153protected:
154 TosaTileAttribute* attribute;
155 TosaReference::TensorTemplate<TIn>* in;
156 TosaReference::TensorTemplate<TOut>* out;
157};
158
159// primary template for op tile
160template <int Rank, DType Dtype>
161class OpTile : public OpTileBase<Rank, Dtype>
162{
163public:
164 OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
165 : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_)
166 {}
167
168protected:
169 virtual int eval();
170};
171
172// partial specialization for specific rank
173#define DEF_OP_TILE_RANK(N) \
174 template <DType Dtype> \
175 class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
176 { \
177 public: \
178 OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
179 : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \
180 {} \
181 \
182 protected: \
183 virtual int eval(); \
184 };
185
186DEF_OP_TILE_RANK(1)
187DEF_OP_TILE_RANK(2)
188DEF_OP_TILE_RANK(3)
189DEF_OP_TILE_RANK(4)
190
191#undef DEF_OP_TILE_RANK
192
193template <int Rank, DType Dtype>
194class OpTranspose : public GraphNode
195{
196public:
197 OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
198 virtual ~OpTranspose();
199
200 virtual int checkTensorAttributes();
201 virtual int eval();
202
203 using InEigenType = typename GetEigenType<Dtype>::type;
204 using OutEigenType = typename GetEigenType<Dtype>::type;
205 using TIn = Eigen::Tensor<InEigenType, Rank>;
206 using TOut = Eigen::Tensor<OutEigenType, Rank>;
207
208protected:
209 Eigen::array<int, Rank> perm_array;
210 TosaReference::TensorTemplate<TIn>* in;
211 TosaReference::TensorTemplate<ETensor1<int32_t>>* perm_tensor;
212 TosaReference::TensorTemplate<TOut>* out;
213};
214}; // namespace TosaReference
215
216#endif