blob: 287ad92b791d22c791ee3c44ddd6c53360bf7529 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2021, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
28 TosaQuantInfoBase* qinfo_,
29 uint64_t id_)
30 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(2, 1);
33 setRequiredRank(0, 6);
34
Kevin Chengc42addc2021-09-28 15:41:57 -070035 a = b = nullptr;
36 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070037
38 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
39}
40
41template <int Rank, DType InDtype, DType OutDtype>
42BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
43{}
44
45template <int Rank, DType InDtype, DType OutDtype>
46int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
47{
48 if (validateRequiredOperands())
49 return 1;
50
51 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
52 {
53 return 1;
54 }
55
Kevin Chengc42addc2021-09-28 15:41:57 -070056 // A & B must be the same rank and types
57 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
59 printNodeValidationError("Binary operator input types must match");
60 return 1;
61 }
62
Kevin Cheng1c3c8472021-11-08 11:19:10 -080063 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070064 {
65 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080066 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
67 printNodeValidationError(err.c_str());
68 return 1;
69 }
70
71 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
72 {
73 std::string err =
74 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070075 printNodeValidationError(err.c_str());
76 return 1;
77 }
Eric Kunzee5e26762020-10-13 16:11:07 -070078
Kevin Chengcc61be32021-10-14 17:09:57 -070079 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
80
Kevin Chengc42addc2021-09-28 15:41:57 -070081 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
82 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070083 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
84
Kevin Chengc42addc2021-09-28 15:41:57 -070085 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070086
87 return 0;
88}
89
90template <int Rank, DType InDtype, DType OutDtype>
91int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
92{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080093 const std::vector<int>& a_shape = a->getShape();
94 const std::vector<int>& b_shape = b->getShape();
95 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070096
Kevin Cheng1c3c8472021-11-08 11:19:10 -080097 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070098 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080099 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
100 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700101 }
102
103 return 0;
104}
105
106template <int Rank, DType InDtype, DType OutDtype>
107int BinaryNode<Rank, InDtype, OutDtype>::eval()
108{
109 this->broadcast();
110
111 Eigen::array<int, Rank> reshaper;
112 reshaper.fill(1);
113 TIn ia, ib;
114
Kevin Chengc42addc2021-09-28 15:41:57 -0700115 ia = this->a->getTensor().broadcast(this->bcast_a);
116 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700117
118 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
119
120 return GraphNode::eval();
121}
122
123// still need to partial specialize this, or Eigen will throw static assertion
124template <DType InDtype, DType OutDtype>
125int BinaryNode<0, InDtype, OutDtype>::eval()
126{
127 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
128
129 return GraphNode::eval();
130}
131
132template <int Rank, DType Dtype>
133int OpAdd<Rank, Dtype>::register_fcn()
134{
135 switch (InDtype)
136 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700137 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100138 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
139 int64_t res_in_64 = static_cast<int64_t>(a) + b;
140 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
141 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
142 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
143 return static_cast<InEigenType>(res_in_64);
144 };
145 break;
146 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700147 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
148 break;
149 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700150 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 }
152
153 return 0;
154}
155
156template <int Rank, DType Dtype>
157int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
158{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800159 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 int32_t num_bits = 0;
161 switch (Dtype)
162 {
163 case DType_INT8:
164 num_bits = 8;
165 break;
166 case DType_INT16:
167 num_bits = 16;
168 break;
169 case DType_INT32:
170 num_bits = 32;
171 break;
172 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700173 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 }
175
Kevin Chengaee1fac2020-11-11 13:54:06 -0800176 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700177 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
178 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800179
180 InEigenType acc = a >> b;
181
182 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
183 {
184 acc++;
185 }
186
187 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 };
189
190 return 0;
191}
192
193template <int Rank, DType Dtype>
194int OpBitwiseAnd<Rank, Dtype>::register_fcn()
195{
196 switch (Dtype)
197 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800198 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700199 case DType_INT16:
200 case DType_INT32:
201 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
202 break;
203 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700204 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700205 }
206
207 return 0;
208}
209
210template <int Rank, DType Dtype>
211int OpBitwiseOr<Rank, Dtype>::register_fcn()
212{
213 switch (Dtype)
214 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800215 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 case DType_INT16:
217 case DType_INT32:
218 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
219 break;
220 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700221 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 }
223
224 return 0;
225}
226
227template <int Rank, DType Dtype>
228int OpBitwiseXor<Rank, Dtype>::register_fcn()
229{
230 switch (Dtype)
231 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800232 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700233 case DType_INT16:
234 case DType_INT32:
235 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
236 break;
237 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700238 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700239 }
240
241 return 0;
242}
243
244template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100245int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700246{
247 switch (InDtype)
248 {
249 case DType_INT32:
250 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100251 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700252 int64_t res_in_64 = static_cast<int64_t>(a) / b;
253 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100254 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
255 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700256 return static_cast<InEigenType>(res_in_64);
257 };
258 break;
259 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700260 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700261 }
262
263 return 0;
264}
265
266template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700267int OpLogicalAnd<Rank, Dtype>::register_fcn()
268{
269 switch (Dtype)
270 {
271 case DType_BOOL:
272 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
273 break;
274 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700275 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700276 }
277
278 return 0;
279}
280
281template <int Rank, DType Dtype>
282int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
283{
284 switch (Dtype)
285 {
286 case DType_INT8:
287 case DType_INT16:
288 case DType_INT32:
289 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
290 break;
291 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700292 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 }
294
295 return 0;
296}
297
298template <int Rank, DType Dtype>
299int OpLogicalRightShift<Rank, Dtype>::register_fcn()
300{
301 int32_t num_bits = 0;
302 switch (Dtype)
303 {
304 case DType_INT8:
305 num_bits = 8;
306 break;
307 case DType_INT16:
308 num_bits = 16;
309 break;
310 case DType_INT32:
311 num_bits = 32;
312 break;
313 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700314 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700315 }
316
317 this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
318 uint32_t mask = ONES_MASK(num_bits) >> b;
319 return (a >> b) & mask;
320 };
321
322 return 0;
323}
324
325template <int Rank, DType Dtype>
326int OpLogicalOr<Rank, Dtype>::register_fcn()
327{
328 switch (Dtype)
329 {
330 case DType_BOOL:
331 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
332 break;
333 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700334 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700335 }
336
337 return 0;
338}
339
340template <int Rank, DType Dtype>
341int OpLogicalXor<Rank, Dtype>::register_fcn()
342{
343 switch (Dtype)
344 {
345 case DType_BOOL:
346 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
347 break;
348 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700349 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700350 }
351
352 return 0;
353}
354
355template <int Rank, DType Dtype>
356int OpMaximum<Rank, Dtype>::register_fcn()
357{
358 switch (Dtype)
359 {
360 case DType_FLOAT:
361 case DType_INT32:
362 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
363 break;
364 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700365 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700366 }
367
368 return 0;
369}
370
371template <int Rank, DType Dtype>
372int OpMinimum<Rank, Dtype>::register_fcn()
373{
374 switch (Dtype)
375 {
376 case DType_FLOAT:
377 case DType_INT32:
378 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
379 break;
380 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700381 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700382 }
383
384 return 0;
385}
386
387template <int Rank, DType InDtype, DType OutDtype>
388int OpMul<Rank, InDtype, OutDtype>::register_fcn()
389{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800390 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800391
Eric Kunzee5e26762020-10-13 16:11:07 -0700392 switch (InDtype)
393 {
394 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800395 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
396 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700397 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800398 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
399 int64_t result;
400 if (shift > 0)
401 {
402 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700403 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800404 result = result >> shift;
405
Kevin Chengacb550f2021-06-29 15:32:19 -0700406 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
407 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800408 }
409 else
410 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700411 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100412 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
413 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
414 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
415 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800416 }
417
418 return static_cast<OutEigenType>(result);
419 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700420 break;
421 case DType_INT8:
422 case DType_INT16:
423 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
424 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
425
426 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
427
428 return clamped_output;
429 };
430 break;
431 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700432 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700433 }
434
435 return 0;
436}
437
438template <int Rank, DType Dtype>
439int OpPow<Rank, Dtype>::register_fcn()
440{
441 switch (Dtype)
442 {
443 case DType_FLOAT:
444 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
445 break;
446 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700447 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700448 }
449
450 return 0;
451}
452
453template <int Rank, DType Dtype>
454int OpSub<Rank, Dtype>::register_fcn()
455{
456 switch (InDtype)
457 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700458 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100459 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
460 int64_t res_in_64 = static_cast<int64_t>(a) - b;
461 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
462 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
463 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
464 return static_cast<InEigenType>(res_in_64);
465 };
466 break;
467 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700468 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
469 break;
470 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700471 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700472 }
473
474 return 0;
475}
476
Kevin Cheng571f7182021-05-24 17:20:01 -0700477template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700478OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
479 TosaAttributeBase* attribute_,
480 TosaQuantInfoBase* qinfo_,
481 uint64_t id_)
482 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700483{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000484 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700485 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000486
487 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700488}
489
Kevin Cheng571f7182021-05-24 17:20:01 -0700490template <int Rank, DType InDtype>
491OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700492{}
493
Kevin Cheng571f7182021-05-24 17:20:01 -0700494template <int Rank, DType InDtype>
495int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700496{
497 if (validateRequiredOperands())
498 return 1;
499
500 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
501 {
502 return 1;
503 }
504
Kevin Chengfe392ce2021-10-18 21:51:55 +0000505 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000506 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000507 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
508
509 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000511 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700512 }
513
Kevin Chengfe392ce2021-10-18 21:51:55 +0000514 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
515 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700516
Kevin Chengfe392ce2021-10-18 21:51:55 +0000517 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700518
519 return 0;
520}
521
Kevin Cheng571f7182021-05-24 17:20:01 -0700522template <int Rank, DType InDtype>
523int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700524{
Kevin Cheng571f7182021-05-24 17:20:01 -0700525 switch (InDtype)
526 {
527 case DType_INT8:
528 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
529 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
530 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000531 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700532
Kevin Cheng571f7182021-05-24 17:20:01 -0700533 return value;
534 });
535 break;
536 case DType_INT16:
537 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
538 // 1. make sure input is int16 range
539 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700540
Kevin Cheng571f7182021-05-24 17:20:01 -0700541 // 2. calculate index and interpolation fraction
542 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
543 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
544 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700545
Kevin Cheng571f7182021-05-24 17:20:01 -0700546 // 3. interpolate, generate 16.7 (23-bit) output
Kevin Chengfe392ce2021-10-18 21:51:55 +0000547 int32_t base = table[index];
548 int32_t next = table[index + 1];
Kevin Cheng571f7182021-05-24 17:20:01 -0700549 int32_t value = (base << 7) + (next - base) * frac;
550
551 return value;
552 });
553 break;
554 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700555 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700556 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700557
558 return GraphNode::eval();
559}
560
561// template explicit instantiation
562DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
563DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
564
565DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
566DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
567DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
568
Kevin Cheng3a478572021-01-22 17:21:02 -0800569DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700570DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
571DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
572
Kevin Cheng3a478572021-01-22 17:21:02 -0800573DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700574DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
575DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
576
Kevin Cheng3a478572021-01-22 17:21:02 -0800577DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700578DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
579DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
580
Matthew Haddon459443c2021-08-23 16:43:13 +0100581DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700582
Eric Kunzee5e26762020-10-13 16:11:07 -0700583DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
584
585DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
586DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
587DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
588
589DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
590DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
591DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
592
593DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
594
595DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
596
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
598DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
599
600DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
601DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
602
603DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
604DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
605DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
606DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
607
608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
609
610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
611DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
612
Kevin Cheng571f7182021-05-24 17:20:01 -0700613DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
614DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700615
616DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
617DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);