
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
28 TosaQuantInfoBase* qinfo_,
29 uint64_t id_)
30 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(2, 1);
33 setRequiredRank(0, 6);
34
Kevin Chengc42addc2021-09-28 15:41:57 -070035 a = b = nullptr;
36 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070037
38 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
39}
40
41template <int Rank, DType InDtype, DType OutDtype>
42BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
43{}
44
45template <int Rank, DType InDtype, DType OutDtype>
46int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
47{
48 if (validateRequiredOperands())
49 return 1;
50
51 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
52 {
53 return 1;
54 }
55
Kevin Chengc42addc2021-09-28 15:41:57 -070056 // A & B must be the same rank and types
57 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
59 printNodeValidationError("Binary operator input types must match");
60 return 1;
61 }
62
Kevin Chengc42addc2021-09-28 15:41:57 -070063 // Input and output rank must match
64 // If it's not MUL, type also needs to match as well.
65 if (nodeType != Op_MUL)
Eric Kunzee5e26762020-10-13 16:11:07 -070066 {
Kevin Chengc42addc2021-09-28 15:41:57 -070067 if (inputs[0]->matchRankType(*outputs[0]))
68 {
69 printNodeValidationError("Binary operators (except MUL) input and output rank and type must match");
70 return 1;
71 }
Eric Kunzee5e26762020-10-13 16:11:07 -070072 }
73 else
74 {
Kevin Chengc42addc2021-09-28 15:41:57 -070075 if (inputs[0]->matchRank(*outputs[0]))
76 {
77 printNodeValidationError("MUL operator input and output rank must match");
78 return 1;
79 }
Eric Kunzee5e26762020-10-13 16:11:07 -070080 }
81
Kevin Chengc42addc2021-09-28 15:41:57 -070082 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
83 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070084 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
85
Kevin Chengc42addc2021-09-28 15:41:57 -070086 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070087
88 return 0;
89}
90
91template <int Rank, DType InDtype, DType OutDtype>
92int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
93{
94 auto output_shape = result->getTensor().dimensions();
95
96 std::vector<int> a_shape, b_shape;
97
Kevin Chengc42addc2021-09-28 15:41:57 -070098 a_shape = a->getShape();
99 b_shape = b->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -0700100
Kevin Chengc42addc2021-09-28 15:41:57 -0700101 for (int i = 0; i < (int)a_shape.size(); i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 {
103 if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
104 {
105 bcast_a[i] = output_shape[i];
106 }
107 else
108 {
109 bcast_a[i] = 1;
110 }
111 if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
112 {
113 bcast_b[i] = output_shape[i];
114 }
115 else
116 {
117 bcast_b[i] = 1;
118 }
119 }
120
121 return 0;
122}
123
124template <int Rank, DType InDtype, DType OutDtype>
125int BinaryNode<Rank, InDtype, OutDtype>::eval()
126{
127 this->broadcast();
128
129 Eigen::array<int, Rank> reshaper;
130 reshaper.fill(1);
131 TIn ia, ib;
132
Kevin Chengc42addc2021-09-28 15:41:57 -0700133 ia = this->a->getTensor().broadcast(this->bcast_a);
134 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700135
136 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
137
138 return GraphNode::eval();
139}
140
141// still need to partial specialize this, or Eigen will throw static assertion
142template <DType InDtype, DType OutDtype>
143int BinaryNode<0, InDtype, OutDtype>::eval()
144{
145 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
146
147 return GraphNode::eval();
148}
149
150template <int Rank, DType Dtype>
151int OpAdd<Rank, Dtype>::register_fcn()
152{
153 switch (InDtype)
154 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700155 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100156 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
157 int64_t res_in_64 = static_cast<int64_t>(a) + b;
158 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
159 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
160 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
161 return static_cast<InEigenType>(res_in_64);
162 };
163 break;
164 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700165 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
166 break;
167 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700168 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700169 }
170
171 return 0;
172}
173
174template <int Rank, DType Dtype>
175int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
176{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800177 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 int32_t num_bits = 0;
179 switch (Dtype)
180 {
181 case DType_INT8:
182 num_bits = 8;
183 break;
184 case DType_INT16:
185 num_bits = 16;
186 break;
187 case DType_INT32:
188 num_bits = 32;
189 break;
190 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700191 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 }
193
Kevin Chengaee1fac2020-11-11 13:54:06 -0800194 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700195 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
196 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800197
198 InEigenType acc = a >> b;
199
200 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
201 {
202 acc++;
203 }
204
205 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 };
207
208 return 0;
209}
210
211template <int Rank, DType Dtype>
212int OpBitwiseAnd<Rank, Dtype>::register_fcn()
213{
214 switch (Dtype)
215 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800216 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700217 case DType_INT16:
218 case DType_INT32:
219 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
220 break;
221 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700222 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 }
224
225 return 0;
226}
227
228template <int Rank, DType Dtype>
229int OpBitwiseOr<Rank, Dtype>::register_fcn()
230{
231 switch (Dtype)
232 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800233 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700234 case DType_INT16:
235 case DType_INT32:
236 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
237 break;
238 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700239 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700240 }
241
242 return 0;
243}
244
245template <int Rank, DType Dtype>
246int OpBitwiseXor<Rank, Dtype>::register_fcn()
247{
248 switch (Dtype)
249 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800250 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700251 case DType_INT16:
252 case DType_INT32:
253 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
254 break;
255 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700256 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700257 }
258
259 return 0;
260}
261
262template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100263int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700264{
265 switch (InDtype)
266 {
267 case DType_INT32:
268 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100269 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700270 int64_t res_in_64 = static_cast<int64_t>(a) / b;
271 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100272 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
273 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700274 return static_cast<InEigenType>(res_in_64);
275 };
276 break;
277 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700278 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700279 }
280
281 return 0;
282}
283
284template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700285int OpLogicalAnd<Rank, Dtype>::register_fcn()
286{
287 switch (Dtype)
288 {
289 case DType_BOOL:
290 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
291 break;
292 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700293 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700294 }
295
296 return 0;
297}
298
299template <int Rank, DType Dtype>
300int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
301{
302 switch (Dtype)
303 {
304 case DType_INT8:
305 case DType_INT16:
306 case DType_INT32:
307 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
308 break;
309 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700310 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700311 }
312
313 return 0;
314}
315
316template <int Rank, DType Dtype>
317int OpLogicalRightShift<Rank, Dtype>::register_fcn()
318{
319 int32_t num_bits = 0;
320 switch (Dtype)
321 {
322 case DType_INT8:
323 num_bits = 8;
324 break;
325 case DType_INT16:
326 num_bits = 16;
327 break;
328 case DType_INT32:
329 num_bits = 32;
330 break;
331 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700332 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700333 }
334
335 this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
336 uint32_t mask = ONES_MASK(num_bits) >> b;
337 return (a >> b) & mask;
338 };
339
340 return 0;
341}
342
343template <int Rank, DType Dtype>
344int OpLogicalOr<Rank, Dtype>::register_fcn()
345{
346 switch (Dtype)
347 {
348 case DType_BOOL:
349 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
350 break;
351 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700352 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700353 }
354
355 return 0;
356}
357
358template <int Rank, DType Dtype>
359int OpLogicalXor<Rank, Dtype>::register_fcn()
360{
361 switch (Dtype)
362 {
363 case DType_BOOL:
364 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
365 break;
366 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700367 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700368 }
369
370 return 0;
371}
372
373template <int Rank, DType Dtype>
374int OpMaximum<Rank, Dtype>::register_fcn()
375{
376 switch (Dtype)
377 {
378 case DType_FLOAT:
379 case DType_INT32:
380 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
381 break;
382 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700383 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700384 }
385
386 return 0;
387}
388
389template <int Rank, DType Dtype>
390int OpMinimum<Rank, Dtype>::register_fcn()
391{
392 switch (Dtype)
393 {
394 case DType_FLOAT:
395 case DType_INT32:
396 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
397 break;
398 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700399 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700400 }
401
402 return 0;
403}
404
405template <int Rank, DType InDtype, DType OutDtype>
406int OpMul<Rank, InDtype, OutDtype>::register_fcn()
407{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800408 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800409
Eric Kunzee5e26762020-10-13 16:11:07 -0700410 switch (InDtype)
411 {
412 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800413 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
414 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700415 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800416 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
417 int64_t result;
418 if (shift > 0)
419 {
420 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700421 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800422 result = result >> shift;
423
Kevin Chengacb550f2021-06-29 15:32:19 -0700424 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
425 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800426 }
427 else
428 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700429 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100430 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
431 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
432 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
433 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800434 }
435
436 return static_cast<OutEigenType>(result);
437 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700438 break;
439 case DType_INT8:
440 case DType_INT16:
441 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
442 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
443
444 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
445
446 return clamped_output;
447 };
448 break;
449 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700450 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700451 }
452
453 return 0;
454}
455
456template <int Rank, DType Dtype>
457int OpPow<Rank, Dtype>::register_fcn()
458{
459 switch (Dtype)
460 {
461 case DType_FLOAT:
462 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
463 break;
464 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700465 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700466 }
467
468 return 0;
469}
470
471template <int Rank, DType Dtype>
472int OpSub<Rank, Dtype>::register_fcn()
473{
474 switch (InDtype)
475 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700476 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100477 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
478 int64_t res_in_64 = static_cast<int64_t>(a) - b;
479 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
480 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
481 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
482 return static_cast<InEigenType>(res_in_64);
483 };
484 break;
485 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700486 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
487 break;
488 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700489 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700490 }
491
492 return 0;
493}
494
Kevin Cheng571f7182021-05-24 17:20:01 -0700495template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700496OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
497 TosaAttributeBase* attribute_,
498 TosaQuantInfoBase* qinfo_,
499 uint64_t id_)
500 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700501{
502 setRequiredOperands(2, 1);
503 setRequiredRank(0, 6);
504}
505
Kevin Cheng571f7182021-05-24 17:20:01 -0700506template <int Rank, DType InDtype>
507OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700508{}
509
Kevin Cheng571f7182021-05-24 17:20:01 -0700510template <int Rank, DType InDtype>
511int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700512{
513 if (validateRequiredOperands())
514 return 1;
515
516 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
517 {
518 return 1;
519 }
520
Kevin Cheng571f7182021-05-24 17:20:01 -0700521 if (inputs[1]->getRank() != 1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700522 {
Kevin Cheng571f7182021-05-24 17:20:01 -0700523 printNodeValidationError("OpTable: Table must be rank 1 tensor");
Eric Kunzee5e26762020-10-13 16:11:07 -0700524 return 1;
525 }
526
Kevin Cheng571f7182021-05-24 17:20:01 -0700527 if (inputs[0]->getDtype() == DType_INT8)
528 {
529 if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
530 {
531 printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
532 return 1;
533 }
534 }
535 else if (inputs[0]->getDtype() == DType_INT16)
536 {
537 if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
538 {
539 printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
540 return 1;
541 }
542 }
543
Eric Kunzee5e26762020-10-13 16:11:07 -0700544 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
545 table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
546 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
547
548 ASSERT_MEM(in && table && out);
549
550 return 0;
551}
552
Kevin Cheng571f7182021-05-24 17:20:01 -0700553template <int Rank, DType InDtype>
554int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700555{
Kevin Cheng571f7182021-05-24 17:20:01 -0700556 switch (InDtype)
557 {
558 case DType_INT8:
559 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
560 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
561 int32_t index = input_truncated - QInMin;
562 int32_t value = this->table->getTensor()(index);
Eric Kunzee5e26762020-10-13 16:11:07 -0700563
Kevin Cheng571f7182021-05-24 17:20:01 -0700564 return value;
565 });
566 break;
567 case DType_INT16:
568 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
569 // 1. make sure input is int16 range
570 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700571
Kevin Cheng571f7182021-05-24 17:20:01 -0700572 // 2. calculate index and interpolation fraction
573 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
574 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
575 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700576
Kevin Cheng571f7182021-05-24 17:20:01 -0700577 // 3. interpolate, generate 16.7 (23-bit) output
578 int32_t base = this->table->getTensor()(index);
579 int32_t next = this->table->getTensor()(index + 1);
580 int32_t value = (base << 7) + (next - base) * frac;
581
582 return value;
583 });
584 break;
585 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700586 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700587 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700588
589 return GraphNode::eval();
590}
591
592// template explicit instantiation
593DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
595
596DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
598DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
599
Kevin Cheng3a478572021-01-22 17:21:02 -0800600DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700601DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
602DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
603
Kevin Cheng3a478572021-01-22 17:21:02 -0800604DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700605DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
606DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
607
Kevin Cheng3a478572021-01-22 17:21:02 -0800608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700609DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
611
Matthew Haddon459443c2021-08-23 16:43:13 +0100612DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700613
Eric Kunzee5e26762020-10-13 16:11:07 -0700614DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
615
616DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
617DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
619
620DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
621DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
622DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
623
624DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
625
626DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
627
628DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
629DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
630
631DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
632DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
633
634DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
635DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
636DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
637DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
638
639DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
640
641DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
642DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
643
Kevin Cheng571f7182021-05-24 17:20:01 -0700644DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
645DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700646
647DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
648DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);