blob: d7bddc0003c86c30ce3df82e33d7f15749c5fe94 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_unary.h"
17#include "quant_util.h"
18#include "template_types.h"
19#include <cmath>
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType Dtype>
26UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
27 : GraphNode(op_, id_)
28{
29 setRequiredOperands(1, 1);
30 setRequiredRank(0, 6);
31
32 fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
33}
34
35template <int Rank, DType Dtype>
36UnaryNode<Rank, Dtype>::~UnaryNode()
37{}
38
39template <int Rank, DType Dtype>
40int UnaryNode<Rank, Dtype>::checkTensorAttributes()
41{
42 if (validateRequiredOperands())
43 return 1;
44
45 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
46 {
47 return 1;
48 }
49
50 // output and input must be the same types
51 if (inputs[0]->matchRankSize(*outputs[0]))
52 {
53 printNodeValidationError("UnaryNode: input and output rank must match");
54 return 1;
55 }
56
57 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
58 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
59
60 ASSERT_MEM(a && result);
61
62 return 0;
63}
64
65template <int Rank, DType Dtype>
66int UnaryNode<Rank, Dtype>::eval()
67{
68 this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
69
70 return GraphNode::eval();
71}
72
73template <int Rank, DType Dtype>
74int OpAbs<Rank, Dtype>::register_fcn()
75{
76 switch (Dtype)
77 {
78 case DType_FLOAT:
79 case DType_INT32:
80 this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
81 break;
82 default:
83 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
84 }
85
86 return 0;
87}
88
89template <int Rank, DType Dtype>
90int OpBitwiseNot<Rank, Dtype>::register_fcn()
91{
92 switch (Dtype)
93 {
94 case DType_AINT8:
95 case DType_INT16:
96 case DType_INT32:
97 this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
98 break;
99 default:
100 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
101 }
102
103 return 0;
104}
105
106template <int Rank, DType Dtype>
107int OpCeil<Rank, Dtype>::register_fcn()
108{
109 switch (Dtype)
110 {
111 case DType_FLOAT:
112 this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
113 break;
114 default:
115 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
116 }
117
118 return 0;
119}
120
121template <int Rank, DType Dtype>
122int OpClz<Rank, Dtype>::register_fcn()
123{
124 int32_t num_bits;
125 switch (Dtype)
126 {
127 case DType_INT32:
128 num_bits = 32;
129 break;
130 default:
131 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
132 }
133
134 this->fcn = [num_bits](int32_t a) -> int32_t {
135 int32_t leading_zeros = 0;
136 for (int bit = num_bits - 1; bit >= 0; bit--)
137 {
138 if (((a >> bit) & 0x1) == 0)
139 {
140 leading_zeros++;
141 }
142 else
143 {
144 break;
145 }
146 }
147 return leading_zeros;
148 };
149
150 return 0;
151}
152
153template <int Rank, DType Dtype>
154int OpExp<Rank, Dtype>::register_fcn()
155{
156 switch (Dtype)
157 {
158 case DType_FLOAT:
159 this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
160 break;
161 default:
162 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
163 }
164
165 return 0;
166}
167
168template <int Rank, DType Dtype>
169int OpFloor<Rank, Dtype>::register_fcn()
170{
171 switch (Dtype)
172 {
173 case DType_FLOAT:
174 this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
175 break;
176 default:
177 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
178 }
179
180 return 0;
181}
182
183template <int Rank, DType Dtype>
184int OpLog<Rank, Dtype>::register_fcn()
185{
186 switch (Dtype)
187 {
188 case DType_FLOAT:
189 this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
190 break;
191 default:
192 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
193 }
194
195 return 0;
196}
197
198template <int Rank, DType Dtype>
199int OpLogicalNot<Rank, Dtype>::register_fcn()
200{
201 switch (Dtype)
202 {
203 case DType_BOOL:
204 this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
205 break;
206 default:
207 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
208 }
209
210 return 0;
211}
212
213template <int Rank, DType Dtype>
214int OpNegate<Rank, Dtype>::register_fcn()
215{
216 switch (Dtype)
217 {
218 case DType_FLOAT:
219 this->fcn = [](InEigenType a) -> OutEigenType {
220 InEigenType result = -(a);
221 return result;
222 };
223 break;
224 case DType_INT16:
225 case DType_INT32:
226 this->fcn = [](InEigenType a) -> OutEigenType {
227 InEigenType result = -(a);
228 return result;
229 };
230 break;
231 case DType_AINT8:
232 ASSERT(this->qinfo);
233 this->fcn = [this](InEigenType a) -> OutEigenType {
234 InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp();
235 return result;
236 };
237 break;
238 default:
239 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
240 }
241
242 return 0;
243}
244
245template <int Rank, DType Dtype>
246int OpReciprocal<Rank, Dtype>::register_fcn()
247{
248 switch (Dtype)
249 {
250 case DType_FLOAT:
251 this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
252 break;
253 default:
254 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
255 }
256
257 return 0;
258}
259
260template <int Rank, DType Dtype>
261int OpRsqrt<Rank, Dtype>::register_fcn()
262{
263 switch (Dtype)
264 {
265 case DType_FLOAT:
266 this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
267 break;
268 default:
269 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
270 }
271
272 return 0;
273}
274
275// template explicit instantiation
276DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
277DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
278
279DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
280DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
281DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
282
283DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
284
285DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
286
287DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
288
289DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
290
291DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
292
293DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
294
295DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
296DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
297DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
298DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
299
300DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
301
302DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);