blob: c33f64679f64a082581ec1cbe9738471f9be855e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
28 TosaQuantInfoBase* qinfo_,
29 uint64_t id_)
30 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(2, 1);
33 setRequiredRank(0, 6);
34
35 a_rank = b_rank = max_input_rank = -1;
36 a = b = nullptr;
37 a_rank0 = b_rank0 = nullptr;
38 result = nullptr;
39
40 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
41}
42
43template <int Rank, DType InDtype, DType OutDtype>
44BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
45{}
46
47template <int Rank, DType InDtype, DType OutDtype>
48int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
49{
50 if (validateRequiredOperands())
51 return 1;
52
53 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
54 {
55 return 1;
56 }
57
58 a_rank = inputs[0]->getRank();
59 b_rank = inputs[1]->getRank();
60 if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
61 {
62 printNodeValidationError("Binary operator input ranks must match");
63 return 1;
64 }
65
66 max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
67
68 // A & B must be the same types
69 if (inputs[0]->matchType(*inputs[1]))
70 {
71 printNodeValidationError("Binary operator input types must match");
72 return 1;
73 }
74
75 // Result's geometry must match, but the type may be wider
76 if (outputs[0]->getRank() != max_input_rank)
77 {
78 printNodeValidationError("Binary operator input and output genometry must match");
79 return 1;
80 }
81
82 if (a_rank == max_input_rank)
83 {
84 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
85 }
86 else
87 {
88 a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
89 }
90
91 if (b_rank == max_input_rank)
92 {
93 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
94 }
95 else
96 {
97 b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
98 }
99
100 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
101
102 // either a or b can be rank0
103 // a_rank0 and b_rank0 can't be valid at the same time.
104 // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0'
105 ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
106
107 return 0;
108}
109
110template <int Rank, DType InDtype, DType OutDtype>
111int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
112{
113 auto output_shape = result->getTensor().dimensions();
114
115 std::vector<int> a_shape, b_shape;
116
117 if (a_rank == max_input_rank)
118 {
119 a_shape = a->getShape();
120 }
121 else
122 {
123 a_shape.assign(max_input_rank, 1);
124 }
125
126 if (b_rank == max_input_rank)
127 {
128 b_shape = b->getShape();
129 }
130 else
131 {
132 b_shape.assign(max_input_rank, 1);
133 }
134
135 for (int i = 0; i < max_input_rank; i++)
136 {
137 if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
138 {
139 bcast_a[i] = output_shape[i];
140 }
141 else
142 {
143 bcast_a[i] = 1;
144 }
145 if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
146 {
147 bcast_b[i] = output_shape[i];
148 }
149 else
150 {
151 bcast_b[i] = 1;
152 }
153 }
154
155 return 0;
156}
157
158template <int Rank, DType InDtype, DType OutDtype>
159int BinaryNode<Rank, InDtype, OutDtype>::eval()
160{
161 this->broadcast();
162
163 Eigen::array<int, Rank> reshaper;
164 reshaper.fill(1);
165 TIn ia, ib;
166
167 if (this->a_rank == this->max_input_rank)
168 {
169 ia = this->a->getTensor().broadcast(this->bcast_a);
170 }
171 else
172 {
173 ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
174 }
175
176 if (this->b_rank == this->max_input_rank)
177 {
178 ib = this->b->getTensor().broadcast(this->bcast_b);
179 }
180 else
181 {
182 ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
183 }
184
185 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
186
187 return GraphNode::eval();
188}
189
190// still need to partial specialize this, or Eigen will throw static assertion
191template <DType InDtype, DType OutDtype>
192int BinaryNode<0, InDtype, OutDtype>::eval()
193{
194 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
195
196 return GraphNode::eval();
197}
198
199template <int Rank, DType Dtype>
200int OpAdd<Rank, Dtype>::register_fcn()
201{
202 switch (InDtype)
203 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700204 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100205 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
206 int64_t res_in_64 = static_cast<int64_t>(a) + b;
207 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
208 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
209 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
210 return static_cast<InEigenType>(res_in_64);
211 };
212 break;
213 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
215 break;
216 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700217 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700218 }
219
220 return 0;
221}
222
223template <int Rank, DType Dtype>
224int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
225{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800226 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700227 int32_t num_bits = 0;
228 switch (Dtype)
229 {
230 case DType_INT8:
231 num_bits = 8;
232 break;
233 case DType_INT16:
234 num_bits = 16;
235 break;
236 case DType_INT32:
237 num_bits = 32;
238 break;
239 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700240 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 }
242
Kevin Chengaee1fac2020-11-11 13:54:06 -0800243 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700244 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
245 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800246
247 InEigenType acc = a >> b;
248
249 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
250 {
251 acc++;
252 }
253
254 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700255 };
256
257 return 0;
258}
259
260template <int Rank, DType Dtype>
261int OpBitwiseAnd<Rank, Dtype>::register_fcn()
262{
263 switch (Dtype)
264 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800265 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700266 case DType_INT16:
267 case DType_INT32:
268 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
269 break;
270 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700271 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700272 }
273
274 return 0;
275}
276
277template <int Rank, DType Dtype>
278int OpBitwiseOr<Rank, Dtype>::register_fcn()
279{
280 switch (Dtype)
281 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800282 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700283 case DType_INT16:
284 case DType_INT32:
285 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
286 break;
287 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700288 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 }
290
291 return 0;
292}
293
294template <int Rank, DType Dtype>
295int OpBitwiseXor<Rank, Dtype>::register_fcn()
296{
297 switch (Dtype)
298 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800299 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 case DType_INT16:
301 case DType_INT32:
302 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
303 break;
304 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700305 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700306 }
307
308 return 0;
309}
310
311template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100312int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700313{
314 switch (InDtype)
315 {
316 case DType_INT32:
317 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100318 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700319 int64_t res_in_64 = static_cast<int64_t>(a) / b;
320 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100321 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
322 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700323 return static_cast<InEigenType>(res_in_64);
324 };
325 break;
326 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700327 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700328 }
329
330 return 0;
331}
332
333template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700334int OpLogicalAnd<Rank, Dtype>::register_fcn()
335{
336 switch (Dtype)
337 {
338 case DType_BOOL:
339 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
340 break;
341 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700342 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700343 }
344
345 return 0;
346}
347
348template <int Rank, DType Dtype>
349int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
350{
351 switch (Dtype)
352 {
353 case DType_INT8:
354 case DType_INT16:
355 case DType_INT32:
356 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
357 break;
358 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700359 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700360 }
361
362 return 0;
363}
364
365template <int Rank, DType Dtype>
366int OpLogicalRightShift<Rank, Dtype>::register_fcn()
367{
368 int32_t num_bits = 0;
369 switch (Dtype)
370 {
371 case DType_INT8:
372 num_bits = 8;
373 break;
374 case DType_INT16:
375 num_bits = 16;
376 break;
377 case DType_INT32:
378 num_bits = 32;
379 break;
380 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700381 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700382 }
383
384 this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
385 uint32_t mask = ONES_MASK(num_bits) >> b;
386 return (a >> b) & mask;
387 };
388
389 return 0;
390}
391
392template <int Rank, DType Dtype>
393int OpLogicalOr<Rank, Dtype>::register_fcn()
394{
395 switch (Dtype)
396 {
397 case DType_BOOL:
398 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
399 break;
400 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700401 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700402 }
403
404 return 0;
405}
406
407template <int Rank, DType Dtype>
408int OpLogicalXor<Rank, Dtype>::register_fcn()
409{
410 switch (Dtype)
411 {
412 case DType_BOOL:
413 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
414 break;
415 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700416 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700417 }
418
419 return 0;
420}
421
422template <int Rank, DType Dtype>
423int OpMaximum<Rank, Dtype>::register_fcn()
424{
425 switch (Dtype)
426 {
427 case DType_FLOAT:
428 case DType_INT32:
429 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
430 break;
431 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700432 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700433 }
434
435 return 0;
436}
437
438template <int Rank, DType Dtype>
439int OpMinimum<Rank, Dtype>::register_fcn()
440{
441 switch (Dtype)
442 {
443 case DType_FLOAT:
444 case DType_INT32:
445 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
446 break;
447 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700448 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700449 }
450
451 return 0;
452}
453
454template <int Rank, DType InDtype, DType OutDtype>
455int OpMul<Rank, InDtype, OutDtype>::register_fcn()
456{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800457 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800458
Eric Kunzee5e26762020-10-13 16:11:07 -0700459 switch (InDtype)
460 {
461 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800462 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
463 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700464 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800465 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
466 int64_t result;
467 if (shift > 0)
468 {
469 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700470 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800471 result = result >> shift;
472
Kevin Chengacb550f2021-06-29 15:32:19 -0700473 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
474 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800475 }
476 else
477 {
Jeremy Johnson90347472021-09-06 12:04:07 +0100478 result = static_cast<int64_t>(a) * b;
479 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
480 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
481 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
482 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800483 }
484
485 return static_cast<OutEigenType>(result);
486 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700487 break;
488 case DType_INT8:
489 case DType_INT16:
490 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
491 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
492
493 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
494
495 return clamped_output;
496 };
497 break;
498 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700499 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700500 }
501
502 return 0;
503}
504
505template <int Rank, DType Dtype>
506int OpPow<Rank, Dtype>::register_fcn()
507{
508 switch (Dtype)
509 {
510 case DType_FLOAT:
511 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
512 break;
513 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700514 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700515 }
516
517 return 0;
518}
519
520template <int Rank, DType Dtype>
521int OpSub<Rank, Dtype>::register_fcn()
522{
523 switch (InDtype)
524 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700525 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100526 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
527 int64_t res_in_64 = static_cast<int64_t>(a) - b;
528 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
529 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
530 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
531 return static_cast<InEigenType>(res_in_64);
532 };
533 break;
534 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700535 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
536 break;
537 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700538 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700539 }
540
541 return 0;
542}
543
Kevin Cheng571f7182021-05-24 17:20:01 -0700544template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700545OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
546 TosaAttributeBase* attribute_,
547 TosaQuantInfoBase* qinfo_,
548 uint64_t id_)
549 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550{
551 setRequiredOperands(2, 1);
552 setRequiredRank(0, 6);
553}
554
Kevin Cheng571f7182021-05-24 17:20:01 -0700555template <int Rank, DType InDtype>
556OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700557{}
558
Kevin Cheng571f7182021-05-24 17:20:01 -0700559template <int Rank, DType InDtype>
560int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700561{
562 if (validateRequiredOperands())
563 return 1;
564
565 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
566 {
567 return 1;
568 }
569
Kevin Cheng571f7182021-05-24 17:20:01 -0700570 if (inputs[1]->getRank() != 1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700571 {
Kevin Cheng571f7182021-05-24 17:20:01 -0700572 printNodeValidationError("OpTable: Table must be rank 1 tensor");
Eric Kunzee5e26762020-10-13 16:11:07 -0700573 return 1;
574 }
575
Kevin Cheng571f7182021-05-24 17:20:01 -0700576 if (inputs[0]->getDtype() == DType_INT8)
577 {
578 if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
579 {
580 printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
581 return 1;
582 }
583 }
584 else if (inputs[0]->getDtype() == DType_INT16)
585 {
586 if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
587 {
588 printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
589 return 1;
590 }
591 }
592
Eric Kunzee5e26762020-10-13 16:11:07 -0700593 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
594 table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
595 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
596
597 ASSERT_MEM(in && table && out);
598
599 return 0;
600}
601
Kevin Cheng571f7182021-05-24 17:20:01 -0700602template <int Rank, DType InDtype>
603int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700604{
Kevin Cheng571f7182021-05-24 17:20:01 -0700605 switch (InDtype)
606 {
607 case DType_INT8:
608 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
609 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
610 int32_t index = input_truncated - QInMin;
611 int32_t value = this->table->getTensor()(index);
Eric Kunzee5e26762020-10-13 16:11:07 -0700612
Kevin Cheng571f7182021-05-24 17:20:01 -0700613 return value;
614 });
615 break;
616 case DType_INT16:
617 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
618 // 1. make sure input is int16 range
619 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700620
Kevin Cheng571f7182021-05-24 17:20:01 -0700621 // 2. calculate index and interpolation fraction
622 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
623 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
624 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700625
Kevin Cheng571f7182021-05-24 17:20:01 -0700626 // 3. interpolate, generate 16.7 (23-bit) output
627 int32_t base = this->table->getTensor()(index);
628 int32_t next = this->table->getTensor()(index + 1);
629 int32_t value = (base << 7) + (next - base) * frac;
630
631 return value;
632 });
633 break;
634 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700635 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700636 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700637
638 return GraphNode::eval();
639}
640
641// template explicit instantiation
642DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
643DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
644
645DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
646DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
647DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
648
Kevin Cheng3a478572021-01-22 17:21:02 -0800649DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700650DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
651DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
652
Kevin Cheng3a478572021-01-22 17:21:02 -0800653DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700654DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
655DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
656
Kevin Cheng3a478572021-01-22 17:21:02 -0800657DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700658DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
659DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
660
Matthew Haddon459443c2021-08-23 16:43:13 +0100661DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700662
Eric Kunzee5e26762020-10-13 16:11:07 -0700663DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
664
665DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
666DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
667DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
668
669DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
670DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
671DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
672
673DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
674
675DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
676
677DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
678DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
679
680DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
681DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
682
683DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
684DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
685DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
686DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
687
688DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
689
690DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
691DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
692
Kevin Cheng571f7182021-05-24 17:20:01 -0700693DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
694DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700695
696DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
697DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);