blob: 415cd1c76207949fdf0e53ab59712489fc968a09 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
28 TosaQuantInfoBase* qinfo_,
29 uint64_t id_)
30 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(2, 1);
33 setRequiredRank(0, 6);
34
Kevin Chengc42addc2021-09-28 15:41:57 -070035 a = b = nullptr;
36 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070037
38 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
39}
40
41template <int Rank, DType InDtype, DType OutDtype>
42BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
43{}
44
45template <int Rank, DType InDtype, DType OutDtype>
46int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
47{
48 if (validateRequiredOperands())
49 return 1;
50
51 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
52 {
53 return 1;
54 }
55
Kevin Chengc42addc2021-09-28 15:41:57 -070056 // A & B must be the same rank and types
57 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
59 printNodeValidationError("Binary operator input types must match");
60 return 1;
61 }
62
Kevin Chengcc61be32021-10-14 17:09:57 -070063 if (inputs[0]->matchRank(*outputs[0]))
Kevin Cheng478101b2021-10-04 10:43:14 -070064 {
65 std::string err =
Kevin Chengcc61be32021-10-14 17:09:57 -070066 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070067 printNodeValidationError(err.c_str());
68 return 1;
69 }
Eric Kunzee5e26762020-10-13 16:11:07 -070070
Kevin Chengcc61be32021-10-14 17:09:57 -070071 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
72
Kevin Chengc42addc2021-09-28 15:41:57 -070073 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
74 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070075 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
76
Kevin Chengc42addc2021-09-28 15:41:57 -070077 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070078
79 return 0;
80}
81
82template <int Rank, DType InDtype, DType OutDtype>
83int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
84{
85 auto output_shape = result->getTensor().dimensions();
86
87 std::vector<int> a_shape, b_shape;
88
Kevin Chengc42addc2021-09-28 15:41:57 -070089 a_shape = a->getShape();
90 b_shape = b->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070091
Kevin Chengc42addc2021-09-28 15:41:57 -070092 for (int i = 0; i < (int)a_shape.size(); i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070093 {
94 if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
95 {
96 bcast_a[i] = output_shape[i];
97 }
98 else
99 {
100 bcast_a[i] = 1;
101 }
102 if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
103 {
104 bcast_b[i] = output_shape[i];
105 }
106 else
107 {
108 bcast_b[i] = 1;
109 }
110 }
111
112 return 0;
113}
114
115template <int Rank, DType InDtype, DType OutDtype>
116int BinaryNode<Rank, InDtype, OutDtype>::eval()
117{
118 this->broadcast();
119
120 Eigen::array<int, Rank> reshaper;
121 reshaper.fill(1);
122 TIn ia, ib;
123
Kevin Chengc42addc2021-09-28 15:41:57 -0700124 ia = this->a->getTensor().broadcast(this->bcast_a);
125 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700126
127 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
128
129 return GraphNode::eval();
130}
131
132// still need to partial specialize this, or Eigen will throw static assertion
133template <DType InDtype, DType OutDtype>
134int BinaryNode<0, InDtype, OutDtype>::eval()
135{
136 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
137
138 return GraphNode::eval();
139}
140
141template <int Rank, DType Dtype>
142int OpAdd<Rank, Dtype>::register_fcn()
143{
144 switch (InDtype)
145 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100147 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
148 int64_t res_in_64 = static_cast<int64_t>(a) + b;
149 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
150 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
151 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
152 return static_cast<InEigenType>(res_in_64);
153 };
154 break;
155 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
157 break;
158 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700159 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 }
161
162 return 0;
163}
164
165template <int Rank, DType Dtype>
166int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
167{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800168 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700169 int32_t num_bits = 0;
170 switch (Dtype)
171 {
172 case DType_INT8:
173 num_bits = 8;
174 break;
175 case DType_INT16:
176 num_bits = 16;
177 break;
178 case DType_INT32:
179 num_bits = 32;
180 break;
181 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700182 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700183 }
184
Kevin Chengaee1fac2020-11-11 13:54:06 -0800185 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700186 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
187 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800188
189 InEigenType acc = a >> b;
190
191 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
192 {
193 acc++;
194 }
195
196 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700197 };
198
199 return 0;
200}
201
202template <int Rank, DType Dtype>
203int OpBitwiseAnd<Rank, Dtype>::register_fcn()
204{
205 switch (Dtype)
206 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800207 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 case DType_INT16:
209 case DType_INT32:
210 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
211 break;
212 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700213 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 }
215
216 return 0;
217}
218
219template <int Rank, DType Dtype>
220int OpBitwiseOr<Rank, Dtype>::register_fcn()
221{
222 switch (Dtype)
223 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800224 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700225 case DType_INT16:
226 case DType_INT32:
227 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
228 break;
229 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700230 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700231 }
232
233 return 0;
234}
235
236template <int Rank, DType Dtype>
237int OpBitwiseXor<Rank, Dtype>::register_fcn()
238{
239 switch (Dtype)
240 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800241 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700242 case DType_INT16:
243 case DType_INT32:
244 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
245 break;
246 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700247 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700248 }
249
250 return 0;
251}
252
253template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100254int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700255{
256 switch (InDtype)
257 {
258 case DType_INT32:
259 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100260 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700261 int64_t res_in_64 = static_cast<int64_t>(a) / b;
262 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100263 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
264 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700265 return static_cast<InEigenType>(res_in_64);
266 };
267 break;
268 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700269 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700270 }
271
272 return 0;
273}
274
275template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700276int OpLogicalAnd<Rank, Dtype>::register_fcn()
277{
278 switch (Dtype)
279 {
280 case DType_BOOL:
281 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
282 break;
283 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700284 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700285 }
286
287 return 0;
288}
289
290template <int Rank, DType Dtype>
291int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
292{
293 switch (Dtype)
294 {
295 case DType_INT8:
296 case DType_INT16:
297 case DType_INT32:
298 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
299 break;
300 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700301 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700302 }
303
304 return 0;
305}
306
307template <int Rank, DType Dtype>
308int OpLogicalRightShift<Rank, Dtype>::register_fcn()
309{
310 int32_t num_bits = 0;
311 switch (Dtype)
312 {
313 case DType_INT8:
314 num_bits = 8;
315 break;
316 case DType_INT16:
317 num_bits = 16;
318 break;
319 case DType_INT32:
320 num_bits = 32;
321 break;
322 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700323 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700324 }
325
326 this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
327 uint32_t mask = ONES_MASK(num_bits) >> b;
328 return (a >> b) & mask;
329 };
330
331 return 0;
332}
333
334template <int Rank, DType Dtype>
335int OpLogicalOr<Rank, Dtype>::register_fcn()
336{
337 switch (Dtype)
338 {
339 case DType_BOOL:
340 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
341 break;
342 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700343 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700344 }
345
346 return 0;
347}
348
349template <int Rank, DType Dtype>
350int OpLogicalXor<Rank, Dtype>::register_fcn()
351{
352 switch (Dtype)
353 {
354 case DType_BOOL:
355 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
356 break;
357 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700358 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700359 }
360
361 return 0;
362}
363
364template <int Rank, DType Dtype>
365int OpMaximum<Rank, Dtype>::register_fcn()
366{
367 switch (Dtype)
368 {
369 case DType_FLOAT:
370 case DType_INT32:
371 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
372 break;
373 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700374 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700375 }
376
377 return 0;
378}
379
380template <int Rank, DType Dtype>
381int OpMinimum<Rank, Dtype>::register_fcn()
382{
383 switch (Dtype)
384 {
385 case DType_FLOAT:
386 case DType_INT32:
387 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
388 break;
389 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700390 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 }
392
393 return 0;
394}
395
396template <int Rank, DType InDtype, DType OutDtype>
397int OpMul<Rank, InDtype, OutDtype>::register_fcn()
398{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800399 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800400
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 switch (InDtype)
402 {
403 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800404 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
405 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700406 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800407 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
408 int64_t result;
409 if (shift > 0)
410 {
411 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700412 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800413 result = result >> shift;
414
Kevin Chengacb550f2021-06-29 15:32:19 -0700415 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
416 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800417 }
418 else
419 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700420 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100421 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
422 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
423 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
424 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800425 }
426
427 return static_cast<OutEigenType>(result);
428 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700429 break;
430 case DType_INT8:
431 case DType_INT16:
432 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
433 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
434
435 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
436
437 return clamped_output;
438 };
439 break;
440 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700441 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700442 }
443
444 return 0;
445}
446
447template <int Rank, DType Dtype>
448int OpPow<Rank, Dtype>::register_fcn()
449{
450 switch (Dtype)
451 {
452 case DType_FLOAT:
453 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
454 break;
455 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700456 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700457 }
458
459 return 0;
460}
461
462template <int Rank, DType Dtype>
463int OpSub<Rank, Dtype>::register_fcn()
464{
465 switch (InDtype)
466 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700467 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100468 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
469 int64_t res_in_64 = static_cast<int64_t>(a) - b;
470 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
471 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
472 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
473 return static_cast<InEigenType>(res_in_64);
474 };
475 break;
476 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700477 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
478 break;
479 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700480 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700481 }
482
483 return 0;
484}
485
Kevin Cheng571f7182021-05-24 17:20:01 -0700486template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700487OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
488 TosaAttributeBase* attribute_,
489 TosaQuantInfoBase* qinfo_,
490 uint64_t id_)
491 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700492{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000493 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700494 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000495
496 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700497}
498
Kevin Cheng571f7182021-05-24 17:20:01 -0700499template <int Rank, DType InDtype>
500OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700501{}
502
Kevin Cheng571f7182021-05-24 17:20:01 -0700503template <int Rank, DType InDtype>
504int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700505{
506 if (validateRequiredOperands())
507 return 1;
508
509 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
510 {
511 return 1;
512 }
513
Kevin Chengfe392ce2021-10-18 21:51:55 +0000514 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
515 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
516
517 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700518 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000519 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700520 }
521
Kevin Chengfe392ce2021-10-18 21:51:55 +0000522 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
523 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700524
Kevin Chengfe392ce2021-10-18 21:51:55 +0000525 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700526
527 return 0;
528}
529
Kevin Cheng571f7182021-05-24 17:20:01 -0700530template <int Rank, DType InDtype>
531int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700532{
Kevin Cheng571f7182021-05-24 17:20:01 -0700533 switch (InDtype)
534 {
535 case DType_INT8:
536 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
537 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
538 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000539 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700540
Kevin Cheng571f7182021-05-24 17:20:01 -0700541 return value;
542 });
543 break;
544 case DType_INT16:
545 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
546 // 1. make sure input is int16 range
547 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700548
Kevin Cheng571f7182021-05-24 17:20:01 -0700549 // 2. calculate index and interpolation fraction
550 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
551 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
552 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Kevin Cheng571f7182021-05-24 17:20:01 -0700554 // 3. interpolate, generate 16.7 (23-bit) output
Kevin Chengfe392ce2021-10-18 21:51:55 +0000555 int32_t base = table[index];
556 int32_t next = table[index + 1];
Kevin Cheng571f7182021-05-24 17:20:01 -0700557 int32_t value = (base << 7) + (next - base) * frac;
558
559 return value;
560 });
561 break;
562 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700563 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700564 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700565
566 return GraphNode::eval();
567}
568
569// template explicit instantiation
570DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
571DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
572
573DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
574DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
575DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
576
Kevin Cheng3a478572021-01-22 17:21:02 -0800577DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700578DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
579DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
580
Kevin Cheng3a478572021-01-22 17:21:02 -0800581DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700582DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
583DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
584
Kevin Cheng3a478572021-01-22 17:21:02 -0800585DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700586DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
587DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
588
Matthew Haddon459443c2021-08-23 16:43:13 +0100589DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700590
Eric Kunzee5e26762020-10-13 16:11:07 -0700591DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
592
593DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
595DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
596
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
598DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
599DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
600
601DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
602
603DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
604
605DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
606DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
607
608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
609DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
610
611DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
612DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
613DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
614DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
615
616DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
617
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
619DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
620
Kevin Cheng571f7182021-05-24 17:20:01 -0700621DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
622DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700623
624DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
625DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);