blob: 023158c0439f2c57c148a1bd0ac9623a08b5cee6 [file] [log] [blame]
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
28 TosaQuantInfoBase* qinfo_,
29 uint64_t id_)
30 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(2, 1);
33 setRequiredRank(0, 6);
34
Kevin Chengc42addc2021-09-28 15:41:57 -070035 a = b = nullptr;
36 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070037
38 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
39}
40
41template <int Rank, DType InDtype, DType OutDtype>
42BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
43{}
44
45template <int Rank, DType InDtype, DType OutDtype>
46int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
47{
48 if (validateRequiredOperands())
49 return 1;
50
51 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
52 {
53 return 1;
54 }
55
Kevin Chengc42addc2021-09-28 15:41:57 -070056 // A & B must be the same rank and types
57 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
59 printNodeValidationError("Binary operator input types must match");
60 return 1;
61 }
62
Kevin Cheng478101b2021-10-04 10:43:14 -070063 // In some ops, only rank of input and output tensor needs to match
64 if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL)
Eric Kunzee5e26762020-10-13 16:11:07 -070065 {
Kevin Chengc42addc2021-09-28 15:41:57 -070066 if (inputs[0]->matchRank(*outputs[0]))
67 {
Kevin Cheng478101b2021-10-04 10:43:14 -070068 std::string err =
69 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
70 printNodeValidationError(err.c_str());
Kevin Chengc42addc2021-09-28 15:41:57 -070071 return 1;
72 }
Eric Kunzee5e26762020-10-13 16:11:07 -070073 }
Kevin Cheng478101b2021-10-04 10:43:14 -070074 // Otherwise both rand/type of input and output must match
75 else if (inputs[0]->matchRankType(*outputs[0]))
76 {
77 std::string err =
78 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match";
79 printNodeValidationError(err.c_str());
80 return 1;
81 }
Eric Kunzee5e26762020-10-13 16:11:07 -070082
Kevin Chengc42addc2021-09-28 15:41:57 -070083 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
84 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070085 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
86
Kevin Chengc42addc2021-09-28 15:41:57 -070087 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070088
89 return 0;
90}
91
92template <int Rank, DType InDtype, DType OutDtype>
93int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
94{
95 auto output_shape = result->getTensor().dimensions();
96
97 std::vector<int> a_shape, b_shape;
98
Kevin Chengc42addc2021-09-28 15:41:57 -070099 a_shape = a->getShape();
100 b_shape = b->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -0700101
Kevin Chengc42addc2021-09-28 15:41:57 -0700102 for (int i = 0; i < (int)a_shape.size(); i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700103 {
104 if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
105 {
106 bcast_a[i] = output_shape[i];
107 }
108 else
109 {
110 bcast_a[i] = 1;
111 }
112 if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
113 {
114 bcast_b[i] = output_shape[i];
115 }
116 else
117 {
118 bcast_b[i] = 1;
119 }
120 }
121
122 return 0;
123}
124
125template <int Rank, DType InDtype, DType OutDtype>
126int BinaryNode<Rank, InDtype, OutDtype>::eval()
127{
128 this->broadcast();
129
130 Eigen::array<int, Rank> reshaper;
131 reshaper.fill(1);
132 TIn ia, ib;
133
Kevin Chengc42addc2021-09-28 15:41:57 -0700134 ia = this->a->getTensor().broadcast(this->bcast_a);
135 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700136
137 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
138
139 return GraphNode::eval();
140}
141
142// still need to partial specialize this, or Eigen will throw static assertion
143template <DType InDtype, DType OutDtype>
144int BinaryNode<0, InDtype, OutDtype>::eval()
145{
146 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
147
148 return GraphNode::eval();
149}
150
151template <int Rank, DType Dtype>
152int OpAdd<Rank, Dtype>::register_fcn()
153{
154 switch (InDtype)
155 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100157 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
158 int64_t res_in_64 = static_cast<int64_t>(a) + b;
159 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
160 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
161 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
162 return static_cast<InEigenType>(res_in_64);
163 };
164 break;
165 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
167 break;
168 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700169 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700170 }
171
172 return 0;
173}
174
175template <int Rank, DType Dtype>
176int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
177{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800178 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 int32_t num_bits = 0;
180 switch (Dtype)
181 {
182 case DType_INT8:
183 num_bits = 8;
184 break;
185 case DType_INT16:
186 num_bits = 16;
187 break;
188 case DType_INT32:
189 num_bits = 32;
190 break;
191 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700192 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 }
194
Kevin Chengaee1fac2020-11-11 13:54:06 -0800195 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700196 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
197 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800198
199 InEigenType acc = a >> b;
200
201 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
202 {
203 acc++;
204 }
205
206 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 };
208
209 return 0;
210}
211
212template <int Rank, DType Dtype>
213int OpBitwiseAnd<Rank, Dtype>::register_fcn()
214{
215 switch (Dtype)
216 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800217 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700218 case DType_INT16:
219 case DType_INT32:
220 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
221 break;
222 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700223 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700224 }
225
226 return 0;
227}
228
229template <int Rank, DType Dtype>
230int OpBitwiseOr<Rank, Dtype>::register_fcn()
231{
232 switch (Dtype)
233 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800234 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700235 case DType_INT16:
236 case DType_INT32:
237 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
238 break;
239 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700240 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 }
242
243 return 0;
244}
245
246template <int Rank, DType Dtype>
247int OpBitwiseXor<Rank, Dtype>::register_fcn()
248{
249 switch (Dtype)
250 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800251 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700252 case DType_INT16:
253 case DType_INT32:
254 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
255 break;
256 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700257 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 }
259
260 return 0;
261}
262
263template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100264int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700265{
266 switch (InDtype)
267 {
268 case DType_INT32:
269 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100270 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700271 int64_t res_in_64 = static_cast<int64_t>(a) / b;
272 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100273 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
274 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700275 return static_cast<InEigenType>(res_in_64);
276 };
277 break;
278 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700279 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700280 }
281
282 return 0;
283}
284
285template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700286int OpLogicalAnd<Rank, Dtype>::register_fcn()
287{
288 switch (Dtype)
289 {
290 case DType_BOOL:
291 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
292 break;
293 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700294 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700295 }
296
297 return 0;
298}
299
300template <int Rank, DType Dtype>
301int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
302{
303 switch (Dtype)
304 {
305 case DType_INT8:
306 case DType_INT16:
307 case DType_INT32:
308 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
309 break;
310 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700311 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700312 }
313
314 return 0;
315}
316
317template <int Rank, DType Dtype>
318int OpLogicalRightShift<Rank, Dtype>::register_fcn()
319{
320 int32_t num_bits = 0;
321 switch (Dtype)
322 {
323 case DType_INT8:
324 num_bits = 8;
325 break;
326 case DType_INT16:
327 num_bits = 16;
328 break;
329 case DType_INT32:
330 num_bits = 32;
331 break;
332 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700333 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700334 }
335
336 this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
337 uint32_t mask = ONES_MASK(num_bits) >> b;
338 return (a >> b) & mask;
339 };
340
341 return 0;
342}
343
344template <int Rank, DType Dtype>
345int OpLogicalOr<Rank, Dtype>::register_fcn()
346{
347 switch (Dtype)
348 {
349 case DType_BOOL:
350 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
351 break;
352 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700353 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700354 }
355
356 return 0;
357}
358
359template <int Rank, DType Dtype>
360int OpLogicalXor<Rank, Dtype>::register_fcn()
361{
362 switch (Dtype)
363 {
364 case DType_BOOL:
365 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
366 break;
367 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700368 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700369 }
370
371 return 0;
372}
373
374template <int Rank, DType Dtype>
375int OpMaximum<Rank, Dtype>::register_fcn()
376{
377 switch (Dtype)
378 {
379 case DType_FLOAT:
380 case DType_INT32:
381 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
382 break;
383 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700384 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700385 }
386
387 return 0;
388}
389
390template <int Rank, DType Dtype>
391int OpMinimum<Rank, Dtype>::register_fcn()
392{
393 switch (Dtype)
394 {
395 case DType_FLOAT:
396 case DType_INT32:
397 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
398 break;
399 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700400 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 }
402
403 return 0;
404}
405
406template <int Rank, DType InDtype, DType OutDtype>
407int OpMul<Rank, InDtype, OutDtype>::register_fcn()
408{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800409 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800410
Eric Kunzee5e26762020-10-13 16:11:07 -0700411 switch (InDtype)
412 {
413 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800414 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
415 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700416 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800417 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
418 int64_t result;
419 if (shift > 0)
420 {
421 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700422 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800423 result = result >> shift;
424
Kevin Chengacb550f2021-06-29 15:32:19 -0700425 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
426 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800427 }
428 else
429 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700430 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100431 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
432 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
433 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
434 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800435 }
436
437 return static_cast<OutEigenType>(result);
438 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700439 break;
440 case DType_INT8:
441 case DType_INT16:
442 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
443 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
444
445 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
446
447 return clamped_output;
448 };
449 break;
450 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700451 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700452 }
453
454 return 0;
455}
456
457template <int Rank, DType Dtype>
458int OpPow<Rank, Dtype>::register_fcn()
459{
460 switch (Dtype)
461 {
462 case DType_FLOAT:
463 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
464 break;
465 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700466 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700467 }
468
469 return 0;
470}
471
472template <int Rank, DType Dtype>
473int OpSub<Rank, Dtype>::register_fcn()
474{
475 switch (InDtype)
476 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700477 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100478 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
479 int64_t res_in_64 = static_cast<int64_t>(a) - b;
480 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
481 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
482 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
483 return static_cast<InEigenType>(res_in_64);
484 };
485 break;
486 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700487 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
488 break;
489 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700490 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700491 }
492
493 return 0;
494}
495
Kevin Cheng571f7182021-05-24 17:20:01 -0700496template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700497OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
498 TosaAttributeBase* attribute_,
499 TosaQuantInfoBase* qinfo_,
500 uint64_t id_)
501 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700502{
503 setRequiredOperands(2, 1);
504 setRequiredRank(0, 6);
505}
506
Kevin Cheng571f7182021-05-24 17:20:01 -0700507template <int Rank, DType InDtype>
508OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700509{}
510
Kevin Cheng571f7182021-05-24 17:20:01 -0700511template <int Rank, DType InDtype>
512int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700513{
514 if (validateRequiredOperands())
515 return 1;
516
517 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
518 {
519 return 1;
520 }
521
Kevin Cheng571f7182021-05-24 17:20:01 -0700522 if (inputs[1]->getRank() != 1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 {
Kevin Cheng571f7182021-05-24 17:20:01 -0700524 printNodeValidationError("OpTable: Table must be rank 1 tensor");
Eric Kunzee5e26762020-10-13 16:11:07 -0700525 return 1;
526 }
527
Kevin Cheng571f7182021-05-24 17:20:01 -0700528 if (inputs[0]->getDtype() == DType_INT8)
529 {
530 if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
531 {
532 printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
533 return 1;
534 }
535 }
536 else if (inputs[0]->getDtype() == DType_INT16)
537 {
538 if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
539 {
540 printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
541 return 1;
542 }
543 }
544
Eric Kunzee5e26762020-10-13 16:11:07 -0700545 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
546 table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
547 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
548
549 ASSERT_MEM(in && table && out);
550
551 return 0;
552}
553
Kevin Cheng571f7182021-05-24 17:20:01 -0700554template <int Rank, DType InDtype>
555int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700556{
Kevin Cheng571f7182021-05-24 17:20:01 -0700557 switch (InDtype)
558 {
559 case DType_INT8:
560 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
561 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
562 int32_t index = input_truncated - QInMin;
563 int32_t value = this->table->getTensor()(index);
Eric Kunzee5e26762020-10-13 16:11:07 -0700564
Kevin Cheng571f7182021-05-24 17:20:01 -0700565 return value;
566 });
567 break;
568 case DType_INT16:
569 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
570 // 1. make sure input is int16 range
571 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700572
Kevin Cheng571f7182021-05-24 17:20:01 -0700573 // 2. calculate index and interpolation fraction
574 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
575 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
576 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700577
Kevin Cheng571f7182021-05-24 17:20:01 -0700578 // 3. interpolate, generate 16.7 (23-bit) output
579 int32_t base = this->table->getTensor()(index);
580 int32_t next = this->table->getTensor()(index + 1);
581 int32_t value = (base << 7) + (next - base) * frac;
582
583 return value;
584 });
585 break;
586 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700587 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700588 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700589
590 return GraphNode::eval();
591}
592
593// template explicit instantiation
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
595DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
596
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
598DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
599DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
600
Kevin Cheng3a478572021-01-22 17:21:02 -0800601DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700602DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
603DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
604
Kevin Cheng3a478572021-01-22 17:21:02 -0800605DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700606DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
607DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
608
Kevin Cheng3a478572021-01-22 17:21:02 -0800609DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
611DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
612
Matthew Haddon459443c2021-08-23 16:43:13 +0100613DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700614
Eric Kunzee5e26762020-10-13 16:11:07 -0700615DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
616
617DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
619DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
620
621DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
622DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
623DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
624
625DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
626
627DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
628
629DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
630DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
631
632DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
633DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
634
635DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
636DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
637DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
638DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
639
640DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
641
642DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
643DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
644
Kevin Cheng571f7182021-05-24 17:20:01 -0700645DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
646DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700647
648DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
649DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);