blob: 917d56e7e20744f41de37e356f737610df0d40a4 [file] [log] [blame]

// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
Kevin Chengacb550f2021-06-29 15:32:19 -070028 uint64_t id_)
29 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(2, 1);
32 setRequiredRank(0, 6);
33
Kevin Chengc42addc2021-09-28 15:41:57 -070034 a = b = nullptr;
35 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070036
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
Kevin Chengc42addc2021-09-28 15:41:57 -070055 // A & B must be the same rank and types
56 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 printNodeValidationError("Binary operator input types must match");
59 return 1;
60 }
61
Kevin Cheng1c3c8472021-11-08 11:19:10 -080062 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070063 {
64 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080065 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
66 printNodeValidationError(err.c_str());
67 return 1;
68 }
69
70 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
71 {
72 std::string err =
73 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070074 printNodeValidationError(err.c_str());
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
Kevin Chengcc61be32021-10-14 17:09:57 -070078 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
79
Kevin Chengc42addc2021-09-28 15:41:57 -070080 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
81 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
83
Kevin Chengc42addc2021-09-28 15:41:57 -070084 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070085
86 return 0;
87}
88
89template <int Rank, DType InDtype, DType OutDtype>
90int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
91{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080092 const std::vector<int>& a_shape = a->getShape();
93 const std::vector<int>& b_shape = b->getShape();
94 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070095
Kevin Cheng1c3c8472021-11-08 11:19:10 -080096 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070097 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080098 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
99 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 }
101
102 return 0;
103}
104
105template <int Rank, DType InDtype, DType OutDtype>
106int BinaryNode<Rank, InDtype, OutDtype>::eval()
107{
108 this->broadcast();
109
110 Eigen::array<int, Rank> reshaper;
111 reshaper.fill(1);
112 TIn ia, ib;
113
Kevin Chengc42addc2021-09-28 15:41:57 -0700114 ia = this->a->getTensor().broadcast(this->bcast_a);
115 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700116
117 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
118
119 return GraphNode::eval();
120}
121
122// still need to partial specialize this, or Eigen will throw static assertion
123template <DType InDtype, DType OutDtype>
124int BinaryNode<0, InDtype, OutDtype>::eval()
125{
126 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
127
128 return GraphNode::eval();
129}
130
131template <int Rank, DType Dtype>
132int OpAdd<Rank, Dtype>::register_fcn()
133{
134 switch (InDtype)
135 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100137 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
138 int64_t res_in_64 = static_cast<int64_t>(a) + b;
139 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
140 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
141 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
142 return static_cast<InEigenType>(res_in_64);
143 };
144 break;
James Ward8b390432022-08-12 20:48:56 +0100145 case DType_FP16:
Jeremy Johnson90347472021-09-06 12:04:07 +0100146 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700147 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
148 break;
149 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700150 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 }
152
153 return 0;
154}
155
156template <int Rank, DType Dtype>
157int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
158{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800159 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 int32_t num_bits = 0;
161 switch (Dtype)
162 {
163 case DType_INT8:
164 num_bits = 8;
165 break;
166 case DType_INT16:
167 num_bits = 16;
168 break;
169 case DType_INT32:
170 num_bits = 32;
171 break;
172 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700173 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 }
175
Kevin Chengaee1fac2020-11-11 13:54:06 -0800176 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700177 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
178 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800179
180 InEigenType acc = a >> b;
181
182 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
183 {
184 acc++;
185 }
186
187 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 };
189
190 return 0;
191}
192
193template <int Rank, DType Dtype>
194int OpBitwiseAnd<Rank, Dtype>::register_fcn()
195{
196 switch (Dtype)
197 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800198 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700199 case DType_INT16:
200 case DType_INT32:
201 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
202 break;
203 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700204 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700205 }
206
207 return 0;
208}
209
210template <int Rank, DType Dtype>
211int OpBitwiseOr<Rank, Dtype>::register_fcn()
212{
213 switch (Dtype)
214 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800215 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 case DType_INT16:
217 case DType_INT32:
218 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
219 break;
220 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700221 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 }
223
224 return 0;
225}
226
227template <int Rank, DType Dtype>
228int OpBitwiseXor<Rank, Dtype>::register_fcn()
229{
230 switch (Dtype)
231 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800232 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700233 case DType_INT16:
234 case DType_INT32:
235 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
236 break;
237 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700238 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700239 }
240
241 return 0;
242}
243
244template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100245int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700246{
247 switch (InDtype)
248 {
249 case DType_INT32:
250 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100251 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700252 int64_t res_in_64 = static_cast<int64_t>(a) / b;
253 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100254 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
255 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700256 return static_cast<InEigenType>(res_in_64);
257 };
258 break;
259 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700260 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700261 }
262
263 return 0;
264}
265
266template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700267int OpLogicalAnd<Rank, Dtype>::register_fcn()
268{
269 switch (Dtype)
270 {
271 case DType_BOOL:
272 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
273 break;
274 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700275 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700276 }
277
278 return 0;
279}
280
281template <int Rank, DType Dtype>
282int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
283{
Jeremy Johnson66bad802022-01-18 14:48:35 +0000284 int32_t num_bits = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700285 switch (Dtype)
286 {
287 case DType_INT8:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000288 num_bits = 8;
289 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700290 case DType_INT16:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000291 num_bits = 16;
292 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 case DType_INT32:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000294 num_bits = 32;
Eric Kunzee5e26762020-10-13 16:11:07 -0700295 break;
296 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700297 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700298 }
Jeremy Johnson66bad802022-01-18 14:48:35 +0000299 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
300 uint32_t mask = ONES_MASK(num_bits);
301 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
302 (int32_t)b);
303 return (a << b) & mask;
304 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700305
306 return 0;
307}
308
309template <int Rank, DType Dtype>
310int OpLogicalRightShift<Rank, Dtype>::register_fcn()
311{
312 int32_t num_bits = 0;
313 switch (Dtype)
314 {
315 case DType_INT8:
316 num_bits = 8;
317 break;
318 case DType_INT16:
319 num_bits = 16;
320 break;
321 case DType_INT32:
322 num_bits = 32;
323 break;
324 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700325 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700326 }
327
Jeremy Johnson66bad802022-01-18 14:48:35 +0000328 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Eric Kunzee5e26762020-10-13 16:11:07 -0700329 uint32_t mask = ONES_MASK(num_bits) >> b;
Jeremy Johnson66bad802022-01-18 14:48:35 +0000330 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
331 (int32_t)b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700332 return (a >> b) & mask;
333 };
334
335 return 0;
336}
337
338template <int Rank, DType Dtype>
339int OpLogicalOr<Rank, Dtype>::register_fcn()
340{
341 switch (Dtype)
342 {
343 case DType_BOOL:
344 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
345 break;
346 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700347 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700348 }
349
350 return 0;
351}
352
353template <int Rank, DType Dtype>
354int OpLogicalXor<Rank, Dtype>::register_fcn()
355{
356 switch (Dtype)
357 {
358 case DType_BOOL:
359 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
360 break;
361 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700362 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700363 }
364
365 return 0;
366}
367
368template <int Rank, DType Dtype>
369int OpMaximum<Rank, Dtype>::register_fcn()
370{
371 switch (Dtype)
372 {
James Ward8b390432022-08-12 20:48:56 +0100373 case DType_FP16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700374 case DType_FLOAT:
375 case DType_INT32:
376 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
377 break;
378 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700379 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700380 }
381
382 return 0;
383}
384
385template <int Rank, DType Dtype>
386int OpMinimum<Rank, Dtype>::register_fcn()
387{
388 switch (Dtype)
389 {
James Ward8b390432022-08-12 20:48:56 +0100390 case DType_FP16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 case DType_FLOAT:
392 case DType_INT32:
393 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
394 break;
395 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700396 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700397 }
398
399 return 0;
400}
401
402template <int Rank, DType InDtype, DType OutDtype>
403int OpMul<Rank, InDtype, OutDtype>::register_fcn()
404{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800405 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800406
Eric Kunzee5e26762020-10-13 16:11:07 -0700407 switch (InDtype)
408 {
James Ward8b390432022-08-12 20:48:56 +0100409 case DType_FP16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700410 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800411 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
412 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700413 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800414 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
415 int64_t result;
416 if (shift > 0)
417 {
418 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700419 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800420 result = result >> shift;
421
Kevin Chengacb550f2021-06-29 15:32:19 -0700422 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
423 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800424 }
425 else
426 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700427 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100428 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
429 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
430 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
431 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800432 }
433
434 return static_cast<OutEigenType>(result);
435 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700436 break;
437 case DType_INT8:
438 case DType_INT16:
439 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
440 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
441
442 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
443
444 return clamped_output;
445 };
446 break;
447 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700448 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700449 }
450
451 return 0;
452}
453
454template <int Rank, DType Dtype>
455int OpPow<Rank, Dtype>::register_fcn()
456{
457 switch (Dtype)
458 {
James Ward8b390432022-08-12 20:48:56 +0100459 case DType_FP16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700460 case DType_FLOAT:
461 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
462 break;
463 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700464 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700465 }
466
467 return 0;
468}
469
470template <int Rank, DType Dtype>
471int OpSub<Rank, Dtype>::register_fcn()
472{
473 switch (InDtype)
474 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700475 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100476 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
477 int64_t res_in_64 = static_cast<int64_t>(a) - b;
478 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
479 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
480 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
481 return static_cast<InEigenType>(res_in_64);
482 };
483 break;
James Ward8b390432022-08-12 20:48:56 +0100484 case DType_FP16:
Jeremy Johnson90347472021-09-06 12:04:07 +0100485 case DType_FLOAT:
Eric Kunzee5e26762020-10-13 16:11:07 -0700486 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
487 break;
488 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700489 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700490 }
491
492 return 0;
493}
494
Kevin Cheng571f7182021-05-24 17:20:01 -0700495template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700496OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
497 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700498 uint64_t id_)
499 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700500{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000501 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700502 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000503
504 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700505}
506
Kevin Cheng571f7182021-05-24 17:20:01 -0700507template <int Rank, DType InDtype>
508OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700509{}
510
Kevin Cheng571f7182021-05-24 17:20:01 -0700511template <int Rank, DType InDtype>
512int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700513{
514 if (validateRequiredOperands())
515 return 1;
516
517 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
518 {
519 return 1;
520 }
521
Kevin Chengfe392ce2021-10-18 21:51:55 +0000522 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000523 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000524 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
525
526 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700527 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000528 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700529 }
530
Kevin Chengfe392ce2021-10-18 21:51:55 +0000531 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
532 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700533
Kevin Chengfe392ce2021-10-18 21:51:55 +0000534 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700535
536 return 0;
537}
538
Kevin Cheng571f7182021-05-24 17:20:01 -0700539template <int Rank, DType InDtype>
540int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700541{
Kevin Cheng571f7182021-05-24 17:20:01 -0700542 switch (InDtype)
543 {
544 case DType_INT8:
545 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
546 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
547 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000548 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700549
Kevin Cheng571f7182021-05-24 17:20:01 -0700550 return value;
551 });
552 break;
553 case DType_INT16:
554 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
555 // 1. make sure input is int16 range
556 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700557
Kevin Cheng571f7182021-05-24 17:20:01 -0700558 // 2. calculate index and interpolation fraction
559 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
560 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
561 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700562
Jerry Ged511f9e2022-08-12 16:12:40 -0700563 // 3. Add REQUIRE CHECK for extreme large/small slopes
Kevin Chengfe392ce2021-10-18 21:51:55 +0000564 int32_t base = table[index];
565 int32_t next = table[index + 1];
Jerry Ged511f9e2022-08-12 16:12:40 -0700566 int32_t slope = next - base;
567 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range");
568
569 // 4. interpolate, generate 16.7 (23-bit) output
570 int32_t value = (base << 7) + (slope) * frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700571
572 return value;
573 });
574 break;
575 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700576 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700577 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700578
579 return GraphNode::eval();
580}
581
582// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +0100583DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700584DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
585DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
586
587DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
588DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
589DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
590
Kevin Cheng3a478572021-01-22 17:21:02 -0800591DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700592DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
593DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
594
Kevin Cheng3a478572021-01-22 17:21:02 -0800595DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700596DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
598
Kevin Cheng3a478572021-01-22 17:21:02 -0800599DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700600DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
601DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
602
Matthew Haddon459443c2021-08-23 16:43:13 +0100603DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700604
Eric Kunzee5e26762020-10-13 16:11:07 -0700605DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
606
607DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
609DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
610
611DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
612DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
613DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
614
615DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
616
617DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
618
James Ward8b390432022-08-12 20:48:56 +0100619DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700620DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
621DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
622
James Ward8b390432022-08-12 20:48:56 +0100623DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700624DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
625DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
626
James Ward8b390432022-08-12 20:48:56 +0100627DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700628DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
629DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
630DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
631DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
632
James Ward8b390432022-08-12 20:48:56 +0100633DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700634DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
635
James Ward8b390432022-08-12 20:48:56 +0100636DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700637DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
638DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
639
Kevin Cheng571f7182021-05-24 17:20:01 -0700640DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
641DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700642
James Ward8b390432022-08-12 20:48:56 +0100643// Instantiation of nodes for comparison operators opEqual, opGreater
644// and opGreaterEqual
645DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700646DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
647DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);