blob: e4c0ee09a9bce9113a7728ad3595f2d850bb04ce [file] [log] [blame]

// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
Kevin Chengacb550f2021-06-29 15:32:19 -070028 uint64_t id_)
29 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(2, 1);
32 setRequiredRank(0, 6);
33
Kevin Chengc42addc2021-09-28 15:41:57 -070034 a = b = nullptr;
35 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070036
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
Kevin Chengc42addc2021-09-28 15:41:57 -070055 // A & B must be the same rank and types
56 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 printNodeValidationError("Binary operator input types must match");
59 return 1;
60 }
61
Kevin Cheng1c3c8472021-11-08 11:19:10 -080062 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070063 {
64 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080065 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
66 printNodeValidationError(err.c_str());
67 return 1;
68 }
69
70 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
71 {
72 std::string err =
73 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070074 printNodeValidationError(err.c_str());
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
Kevin Chengcc61be32021-10-14 17:09:57 -070078 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
79
Kevin Chengc42addc2021-09-28 15:41:57 -070080 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
81 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
83
Kevin Chengc42addc2021-09-28 15:41:57 -070084 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070085
86 return 0;
87}
88
89template <int Rank, DType InDtype, DType OutDtype>
90int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
91{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080092 const std::vector<int>& a_shape = a->getShape();
93 const std::vector<int>& b_shape = b->getShape();
94 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070095
Kevin Cheng1c3c8472021-11-08 11:19:10 -080096 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070097 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080098 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
99 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 }
101
102 return 0;
103}
104
105template <int Rank, DType InDtype, DType OutDtype>
106int BinaryNode<Rank, InDtype, OutDtype>::eval()
107{
108 this->broadcast();
109
110 Eigen::array<int, Rank> reshaper;
111 reshaper.fill(1);
112 TIn ia, ib;
113
Kevin Chengc42addc2021-09-28 15:41:57 -0700114 ia = this->a->getTensor().broadcast(this->bcast_a);
115 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700116
117 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
118
119 return GraphNode::eval();
120}
121
122// still need to partial specialize this, or Eigen will throw static assertion
123template <DType InDtype, DType OutDtype>
124int BinaryNode<0, InDtype, OutDtype>::eval()
125{
126 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
127
128 return GraphNode::eval();
129}
130
131template <int Rank, DType Dtype>
132int OpAdd<Rank, Dtype>::register_fcn()
133{
134 switch (InDtype)
135 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100137 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
138 int64_t res_in_64 = static_cast<int64_t>(a) + b;
139 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
140 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
141 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
142 return static_cast<InEigenType>(res_in_64);
143 };
144 break;
James Ward8b390432022-08-12 20:48:56 +0100145 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100146 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100147 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100148 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 break;
150 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700151 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700152 }
153
154 return 0;
155}
156
157template <int Rank, DType Dtype>
158int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
159{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800160 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700161 int32_t num_bits = 0;
162 switch (Dtype)
163 {
164 case DType_INT8:
165 num_bits = 8;
166 break;
167 case DType_INT16:
168 num_bits = 16;
169 break;
170 case DType_INT32:
171 num_bits = 32;
172 break;
173 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700174 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700175 }
176
Kevin Chengaee1fac2020-11-11 13:54:06 -0800177 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700178 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
179 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800180
181 InEigenType acc = a >> b;
182
183 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
184 {
185 acc++;
186 }
187
188 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 };
190
191 return 0;
192}
193
194template <int Rank, DType Dtype>
195int OpBitwiseAnd<Rank, Dtype>::register_fcn()
196{
197 switch (Dtype)
198 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800199 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700200 case DType_INT16:
201 case DType_INT32:
202 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
203 break;
204 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700205 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 }
207
208 return 0;
209}
210
211template <int Rank, DType Dtype>
212int OpBitwiseOr<Rank, Dtype>::register_fcn()
213{
214 switch (Dtype)
215 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800216 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700217 case DType_INT16:
218 case DType_INT32:
219 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
220 break;
221 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700222 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 }
224
225 return 0;
226}
227
228template <int Rank, DType Dtype>
229int OpBitwiseXor<Rank, Dtype>::register_fcn()
230{
231 switch (Dtype)
232 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800233 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700234 case DType_INT16:
235 case DType_INT32:
236 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
237 break;
238 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700239 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700240 }
241
242 return 0;
243}
244
245template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100246int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700247{
248 switch (InDtype)
249 {
250 case DType_INT32:
251 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100252 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700253 int64_t res_in_64 = static_cast<int64_t>(a) / b;
254 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100255 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
256 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700257 return static_cast<InEigenType>(res_in_64);
258 };
259 break;
260 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700261 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700262 }
263
264 return 0;
265}
266
267template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700268int OpLogicalAnd<Rank, Dtype>::register_fcn()
269{
270 switch (Dtype)
271 {
272 case DType_BOOL:
273 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
274 break;
275 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700276 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700277 }
278
279 return 0;
280}
281
282template <int Rank, DType Dtype>
283int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
284{
Jeremy Johnson66bad802022-01-18 14:48:35 +0000285 int32_t num_bits = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700286 switch (Dtype)
287 {
288 case DType_INT8:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000289 num_bits = 8;
290 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700291 case DType_INT16:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000292 num_bits = 16;
293 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700294 case DType_INT32:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000295 num_bits = 32;
Eric Kunzee5e26762020-10-13 16:11:07 -0700296 break;
297 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700298 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700299 }
Jeremy Johnson66bad802022-01-18 14:48:35 +0000300 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
301 uint32_t mask = ONES_MASK(num_bits);
302 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
303 (int32_t)b);
304 return (a << b) & mask;
305 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
307 return 0;
308}
309
310template <int Rank, DType Dtype>
311int OpLogicalRightShift<Rank, Dtype>::register_fcn()
312{
313 int32_t num_bits = 0;
314 switch (Dtype)
315 {
316 case DType_INT8:
317 num_bits = 8;
318 break;
319 case DType_INT16:
320 num_bits = 16;
321 break;
322 case DType_INT32:
323 num_bits = 32;
324 break;
325 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700326 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700327 }
328
Jeremy Johnson66bad802022-01-18 14:48:35 +0000329 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Eric Kunzee5e26762020-10-13 16:11:07 -0700330 uint32_t mask = ONES_MASK(num_bits) >> b;
Jeremy Johnson66bad802022-01-18 14:48:35 +0000331 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
332 (int32_t)b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700333 return (a >> b) & mask;
334 };
335
336 return 0;
337}
338
339template <int Rank, DType Dtype>
340int OpLogicalOr<Rank, Dtype>::register_fcn()
341{
342 switch (Dtype)
343 {
344 case DType_BOOL:
345 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
346 break;
347 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700348 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700349 }
350
351 return 0;
352}
353
354template <int Rank, DType Dtype>
355int OpLogicalXor<Rank, Dtype>::register_fcn()
356{
357 switch (Dtype)
358 {
359 case DType_BOOL:
360 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
361 break;
362 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700363 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700364 }
365
366 return 0;
367}
368
369template <int Rank, DType Dtype>
370int OpMaximum<Rank, Dtype>::register_fcn()
371{
372 switch (Dtype)
373 {
James Ward8b390432022-08-12 20:48:56 +0100374 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100375 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100376 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700377 case DType_INT32:
378 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
379 break;
380 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700381 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700382 }
383
384 return 0;
385}
386
387template <int Rank, DType Dtype>
388int OpMinimum<Rank, Dtype>::register_fcn()
389{
390 switch (Dtype)
391 {
James Ward8b390432022-08-12 20:48:56 +0100392 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100393 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100394 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700395 case DType_INT32:
396 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
397 break;
398 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700399 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700400 }
401
402 return 0;
403}
404
405template <int Rank, DType InDtype, DType OutDtype>
406int OpMul<Rank, InDtype, OutDtype>::register_fcn()
407{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800408 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800409
Eric Kunzee5e26762020-10-13 16:11:07 -0700410 switch (InDtype)
411 {
James Ward8b390432022-08-12 20:48:56 +0100412 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100413 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100414 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100415 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
Kevin Chengaee1fac2020-11-11 13:54:06 -0800416 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700417 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800418 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
419 int64_t result;
420 if (shift > 0)
421 {
422 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700423 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800424 result = result >> shift;
425
Kevin Chengacb550f2021-06-29 15:32:19 -0700426 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
427 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800428 }
429 else
430 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700431 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100432 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
433 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
434 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
435 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800436 }
437
438 return static_cast<OutEigenType>(result);
439 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700440 break;
441 case DType_INT8:
442 case DType_INT16:
443 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
444 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
445
446 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
447
448 return clamped_output;
449 };
450 break;
451 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700452 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700453 }
454
455 return 0;
456}
457
458template <int Rank, DType Dtype>
459int OpPow<Rank, Dtype>::register_fcn()
460{
461 switch (Dtype)
462 {
James Ward8b390432022-08-12 20:48:56 +0100463 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100464 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100465 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100466 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700467 break;
468 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700469 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700470 }
471
472 return 0;
473}
474
475template <int Rank, DType Dtype>
476int OpSub<Rank, Dtype>::register_fcn()
477{
478 switch (InDtype)
479 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700480 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100481 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
482 int64_t res_in_64 = static_cast<int64_t>(a) - b;
483 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
484 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
485 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
486 return static_cast<InEigenType>(res_in_64);
487 };
488 break;
James Ward8b390432022-08-12 20:48:56 +0100489 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100490 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100491 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100492 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700493 break;
494 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700495 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700496 }
497
498 return 0;
499}
500
Kevin Cheng571f7182021-05-24 17:20:01 -0700501template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700502OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
503 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700504 uint64_t id_)
505 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700506{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000507 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700508 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000509
510 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700511}
512
Kevin Cheng571f7182021-05-24 17:20:01 -0700513template <int Rank, DType InDtype>
514OpTable<Rank, InDtype>::~OpTable()
Eric Kunzee5e26762020-10-13 16:11:07 -0700515{}
516
Kevin Cheng571f7182021-05-24 17:20:01 -0700517template <int Rank, DType InDtype>
518int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700519{
520 if (validateRequiredOperands())
521 return 1;
522
523 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
524 {
525 return 1;
526 }
527
Kevin Chengfe392ce2021-10-18 21:51:55 +0000528 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000529 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000530 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
531
532 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700533 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000534 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700535 }
536
Kevin Chengfe392ce2021-10-18 21:51:55 +0000537 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
538 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700539
Kevin Chengfe392ce2021-10-18 21:51:55 +0000540 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
542 return 0;
543}
544
Kevin Cheng571f7182021-05-24 17:20:01 -0700545template <int Rank, DType InDtype>
546int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700547{
Kevin Cheng571f7182021-05-24 17:20:01 -0700548 switch (InDtype)
549 {
550 case DType_INT8:
551 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
552 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
553 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000554 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700555
Kevin Cheng571f7182021-05-24 17:20:01 -0700556 return value;
557 });
558 break;
559 case DType_INT16:
560 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
561 // 1. make sure input is int16 range
562 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700563
Kevin Cheng571f7182021-05-24 17:20:01 -0700564 // 2. calculate index and interpolation fraction
565 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
566 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
567 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700568
Jerry Ged511f9e2022-08-12 16:12:40 -0700569 // 3. Add REQUIRE CHECK for extreme large/small slopes
Kevin Chengfe392ce2021-10-18 21:51:55 +0000570 int32_t base = table[index];
571 int32_t next = table[index + 1];
Jerry Ged511f9e2022-08-12 16:12:40 -0700572 int32_t slope = next - base;
573 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range");
574
575 // 4. interpolate, generate 16.7 (23-bit) output
576 int32_t value = (base << 7) + (slope) * frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700577
578 return value;
579 });
580 break;
581 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700582 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700583 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700584
585 return GraphNode::eval();
586}
587
588// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +0100589DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100590DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100591DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700592DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
593
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
595DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
596DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
597
Kevin Cheng3a478572021-01-22 17:21:02 -0800598DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700599DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
600DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
601
Kevin Cheng3a478572021-01-22 17:21:02 -0800602DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700603DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
604DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
605
Kevin Cheng3a478572021-01-22 17:21:02 -0800606DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700607DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
609
Matthew Haddon459443c2021-08-23 16:43:13 +0100610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700611
Eric Kunzee5e26762020-10-13 16:11:07 -0700612DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
613
614DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
615DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
616DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
617
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
619DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
620DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
621
622DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
623
624DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
625
James Ward8b390432022-08-12 20:48:56 +0100626DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100627DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100628DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700629DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
630
James Ward8b390432022-08-12 20:48:56 +0100631DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100632DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100633DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700634DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
635
James Ward8b390432022-08-12 20:48:56 +0100636DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100637DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100638DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700639DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
640DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
641DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
642
James Ward8b390432022-08-12 20:48:56 +0100643DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100644DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100645DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700646
James Ward8b390432022-08-12 20:48:56 +0100647DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100648DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100649DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700650DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
651
Kevin Cheng571f7182021-05-24 17:20:01 -0700652DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
653DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700654
James Ward8b390432022-08-12 20:48:56 +0100655// Instantiation of nodes for comparison operators opEqual, opGreater
656// and opGreaterEqual
657DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
James Ward24dbc422022-10-19 12:20:31 +0100658DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100659DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700660DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);