blob: a5e1a20efb8a5e6c7b4b6e535fb7559bd0132953 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
Kevin Chengacb550f2021-06-29 15:32:19 -070028 uint64_t id_)
29 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(2, 1);
32 setRequiredRank(0, 6);
33
Kevin Chengc42addc2021-09-28 15:41:57 -070034 a = b = nullptr;
35 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070036
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
Kevin Chengc42addc2021-09-28 15:41:57 -070055 // A & B must be the same rank and types
56 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 printNodeValidationError("Binary operator input types must match");
59 return 1;
60 }
61
Kevin Cheng1c3c8472021-11-08 11:19:10 -080062 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070063 {
64 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080065 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
66 printNodeValidationError(err.c_str());
67 return 1;
68 }
69
70 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
71 {
72 std::string err =
73 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070074 printNodeValidationError(err.c_str());
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
Kevin Chengcc61be32021-10-14 17:09:57 -070078 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
79
Kevin Chengc42addc2021-09-28 15:41:57 -070080 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
81 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
83
Kevin Chengc42addc2021-09-28 15:41:57 -070084 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070085
86 return 0;
87}
88
89template <int Rank, DType InDtype, DType OutDtype>
90int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
91{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080092 const std::vector<int>& a_shape = a->getShape();
93 const std::vector<int>& b_shape = b->getShape();
94 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070095
Kevin Cheng1c3c8472021-11-08 11:19:10 -080096 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070097 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080098 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
99 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 }
101
102 return 0;
103}
104
105template <int Rank, DType InDtype, DType OutDtype>
106int BinaryNode<Rank, InDtype, OutDtype>::eval()
107{
108 this->broadcast();
109
110 Eigen::array<int, Rank> reshaper;
111 reshaper.fill(1);
112 TIn ia, ib;
113
Kevin Chengc42addc2021-09-28 15:41:57 -0700114 ia = this->a->getTensor().broadcast(this->bcast_a);
115 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700116
117 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
118
119 return GraphNode::eval();
120}
121
122// still need to partial specialize this, or Eigen will throw static assertion
123template <DType InDtype, DType OutDtype>
124int BinaryNode<0, InDtype, OutDtype>::eval()
125{
126 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
127
128 return GraphNode::eval();
129}
130
131template <int Rank, DType Dtype>
132int OpAdd<Rank, Dtype>::register_fcn()
133{
134 switch (InDtype)
135 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100137 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
138 int64_t res_in_64 = static_cast<int64_t>(a) + b;
139 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
140 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
141 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
142 return static_cast<InEigenType>(res_in_64);
143 };
144 break;
145 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
147 break;
148 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700149 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700150 }
151
152 return 0;
153}
154
155template <int Rank, DType Dtype>
156int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
157{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800158 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700159 int32_t num_bits = 0;
160 switch (Dtype)
161 {
162 case DType_INT8:
163 num_bits = 8;
164 break;
165 case DType_INT16:
166 num_bits = 16;
167 break;
168 case DType_INT32:
169 num_bits = 32;
170 break;
171 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700172 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 }
174
Kevin Chengaee1fac2020-11-11 13:54:06 -0800175 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700176 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
177 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800178
179 InEigenType acc = a >> b;
180
181 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
182 {
183 acc++;
184 }
185
186 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 };
188
189 return 0;
190}
191
192template <int Rank, DType Dtype>
193int OpBitwiseAnd<Rank, Dtype>::register_fcn()
194{
195 switch (Dtype)
196 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800197 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700198 case DType_INT16:
199 case DType_INT32:
200 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
201 break;
202 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700203 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700204 }
205
206 return 0;
207}
208
209template <int Rank, DType Dtype>
210int OpBitwiseOr<Rank, Dtype>::register_fcn()
211{
212 switch (Dtype)
213 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800214 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 case DType_INT16:
216 case DType_INT32:
217 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
218 break;
219 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700220 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700221 }
222
223 return 0;
224}
225
226template <int Rank, DType Dtype>
227int OpBitwiseXor<Rank, Dtype>::register_fcn()
228{
229 switch (Dtype)
230 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800231 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700232 case DType_INT16:
233 case DType_INT32:
234 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
235 break;
236 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700237 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700238 }
239
240 return 0;
241}
242
243template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100244int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700245{
246 switch (InDtype)
247 {
248 case DType_INT32:
249 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100250 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700251 int64_t res_in_64 = static_cast<int64_t>(a) / b;
252 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100253 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
254 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700255 return static_cast<InEigenType>(res_in_64);
256 };
257 break;
258 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700259 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700260 }
261
262 return 0;
263}
264
265template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700266int OpLogicalAnd<Rank, Dtype>::register_fcn()
267{
268 switch (Dtype)
269 {
270 case DType_BOOL:
271 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
272 break;
273 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700274 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700275 }
276
277 return 0;
278}
279
280template <int Rank, DType Dtype>
281int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
282{
Jeremy Johnson66bad802022-01-18 14:48:35 +0000283 int32_t num_bits = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700284 switch (Dtype)
285 {
286 case DType_INT8:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000287 num_bits = 8;
288 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 case DType_INT16:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000290 num_bits = 16;
291 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 case DType_INT32:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000293 num_bits = 32;
Eric Kunzee5e26762020-10-13 16:11:07 -0700294 break;
295 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700296 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700297 }
Jeremy Johnson66bad802022-01-18 14:48:35 +0000298 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
299 uint32_t mask = ONES_MASK(num_bits);
300 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
301 (int32_t)b);
302 return (a << b) & mask;
303 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700304
305 return 0;
306}
307
308template <int Rank, DType Dtype>
309int OpLogicalRightShift<Rank, Dtype>::register_fcn()
310{
311 int32_t num_bits = 0;
312 switch (Dtype)
313 {
314 case DType_INT8:
315 num_bits = 8;
316 break;
317 case DType_INT16:
318 num_bits = 16;
319 break;
320 case DType_INT32:
321 num_bits = 32;
322 break;
323 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700324 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700325 }
326
Jeremy Johnson66bad802022-01-18 14:48:35 +0000327 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Eric Kunzee5e26762020-10-13 16:11:07 -0700328 uint32_t mask = ONES_MASK(num_bits) >> b;
Jeremy Johnson66bad802022-01-18 14:48:35 +0000329 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
330 (int32_t)b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700331 return (a >> b) & mask;
332 };
333
334 return 0;
335}
336
337template <int Rank, DType Dtype>
338int OpLogicalOr<Rank, Dtype>::register_fcn()
339{
340 switch (Dtype)
341 {
342 case DType_BOOL:
343 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
344 break;
345 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700346 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700347 }
348
349 return 0;
350}
351
352template <int Rank, DType Dtype>
353int OpLogicalXor<Rank, Dtype>::register_fcn()
354{
355 switch (Dtype)
356 {
357 case DType_BOOL:
358 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
359 break;
360 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700361 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700362 }
363
364 return 0;
365}
366
367template <int Rank, DType Dtype>
368int OpMaximum<Rank, Dtype>::register_fcn()
369{
370 switch (Dtype)
371 {
372 case DType_FLOAT:
373 case DType_INT32:
374 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
375 break;
376 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700377 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700378 }
379
380 return 0;
381}
382
383template <int Rank, DType Dtype>
384int OpMinimum<Rank, Dtype>::register_fcn()
385{
386 switch (Dtype)
387 {
388 case DType_FLOAT:
389 case DType_INT32:
390 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
391 break;
392 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700393 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700394 }
395
396 return 0;
397}
398
399template <int Rank, DType InDtype, DType OutDtype>
400int OpMul<Rank, InDtype, OutDtype>::register_fcn()
401{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800402 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800403
Eric Kunzee5e26762020-10-13 16:11:07 -0700404 switch (InDtype)
405 {
406 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800407 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
408 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700409 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800410 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
411 int64_t result;
412 if (shift > 0)
413 {
414 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700415 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800416 result = result >> shift;
417
Kevin Chengacb550f2021-06-29 15:32:19 -0700418 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
419 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800420 }
421 else
422 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700423 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100424 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
425 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
426 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
427 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800428 }
429
430 return static_cast<OutEigenType>(result);
431 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700432 break;
433 case DType_INT8:
434 case DType_INT16:
435 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
436 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
437
438 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
439
440 return clamped_output;
441 };
442 break;
443 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700444 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700445 }
446
447 return 0;
448}
449
450template <int Rank, DType Dtype>
451int OpPow<Rank, Dtype>::register_fcn()
452{
453 switch (Dtype)
454 {
455 case DType_FLOAT:
456 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
457 break;
458 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700459 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700460 }
461
462 return 0;
463}
464
465template <int Rank, DType Dtype>
466int OpSub<Rank, Dtype>::register_fcn()
467{
468 switch (InDtype)
469 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700470 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100471 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
472 int64_t res_in_64 = static_cast<int64_t>(a) - b;
473 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
474 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
475 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
476 return static_cast<InEigenType>(res_in_64);
477 };
478 break;
479 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700480 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
481 break;
482 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700483 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700484 }
485
486 return 0;
487}
488
Kevin Cheng571f7182021-05-24 17:20:01 -0700489template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700490OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
491 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700492 uint64_t id_)
493 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700494{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000495 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700496 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000497
498 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700499}
500
Kevin Cheng571f7182021-05-24 17:20:01 -0700501template <int Rank, DType InDtype>
502OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700503{}
504
Kevin Cheng571f7182021-05-24 17:20:01 -0700505template <int Rank, DType InDtype>
506int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700507{
508 if (validateRequiredOperands())
509 return 1;
510
511 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
512 {
513 return 1;
514 }
515
Kevin Chengfe392ce2021-10-18 21:51:55 +0000516 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000517 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000518 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
519
520 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700521 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000522 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 }
524
Kevin Chengfe392ce2021-10-18 21:51:55 +0000525 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
526 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700527
Kevin Chengfe392ce2021-10-18 21:51:55 +0000528 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700529
530 return 0;
531}
532
Kevin Cheng571f7182021-05-24 17:20:01 -0700533template <int Rank, DType InDtype>
534int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700535{
Kevin Cheng571f7182021-05-24 17:20:01 -0700536 switch (InDtype)
537 {
538 case DType_INT8:
539 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
540 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
541 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000542 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700543
Kevin Cheng571f7182021-05-24 17:20:01 -0700544 return value;
545 });
546 break;
547 case DType_INT16:
548 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
549 // 1. make sure input is int16 range
550 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700551
Kevin Cheng571f7182021-05-24 17:20:01 -0700552 // 2. calculate index and interpolation fraction
553 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
554 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
555 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700556
Jerry Ged511f9e2022-08-12 16:12:40 -0700557 // 3. Add REQUIRE CHECK for extreme large/small slopes
Kevin Chengfe392ce2021-10-18 21:51:55 +0000558 int32_t base = table[index];
559 int32_t next = table[index + 1];
Jerry Ged511f9e2022-08-12 16:12:40 -0700560 int32_t slope = next - base;
561 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range");
562
563 // 4. interpolate, generate 16.7 (23-bit) output
564 int32_t value = (base << 7) + (slope) * frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700565
566 return value;
567 });
568 break;
569 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700570 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700571 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700572
573 return GraphNode::eval();
574}
575
576// template explicit instantiation
577DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
578DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
579
580DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
581DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
582DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
583
Kevin Cheng3a478572021-01-22 17:21:02 -0800584DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700585DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
586DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
587
Kevin Cheng3a478572021-01-22 17:21:02 -0800588DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700589DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
590DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
591
Kevin Cheng3a478572021-01-22 17:21:02 -0800592DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700593DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
595
Matthew Haddon459443c2021-08-23 16:43:13 +0100596DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700597
Eric Kunzee5e26762020-10-13 16:11:07 -0700598DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
599
600DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
601DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
602DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
603
604DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
605DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
606DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
607
608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
609
610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
611
612DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
613DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
614
615DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
616DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
617
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
619DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
620DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
621DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
622
623DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
624
625DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
626DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
627
Kevin Cheng571f7182021-05-24 17:20:01 -0700628DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
629DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700630
631DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
632DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);