blob: 57eab5f4753c355ec4c2419b2c1236a7b214f89c [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
Kevin Chengacb550f2021-06-29 15:32:19 -070028 uint64_t id_)
29 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(2, 1);
32 setRequiredRank(0, 6);
33
Kevin Chengc42addc2021-09-28 15:41:57 -070034 a = b = nullptr;
35 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070036
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
Kevin Chengc42addc2021-09-28 15:41:57 -070055 // A & B must be the same rank and types
56 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 printNodeValidationError("Binary operator input types must match");
59 return 1;
60 }
61
Kevin Cheng1c3c8472021-11-08 11:19:10 -080062 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070063 {
64 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080065 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
66 printNodeValidationError(err.c_str());
67 return 1;
68 }
69
70 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
71 {
72 std::string err =
73 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070074 printNodeValidationError(err.c_str());
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
Kevin Chengcc61be32021-10-14 17:09:57 -070078 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
79
Kevin Chengc42addc2021-09-28 15:41:57 -070080 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
81 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
83
Kevin Chengc42addc2021-09-28 15:41:57 -070084 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070085
86 return 0;
87}
88
89template <int Rank, DType InDtype, DType OutDtype>
90int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
91{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080092 const std::vector<int>& a_shape = a->getShape();
93 const std::vector<int>& b_shape = b->getShape();
94 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070095
Kevin Cheng1c3c8472021-11-08 11:19:10 -080096 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070097 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080098 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
99 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 }
101
102 return 0;
103}
104
105template <int Rank, DType InDtype, DType OutDtype>
106int BinaryNode<Rank, InDtype, OutDtype>::eval()
107{
108 this->broadcast();
109
110 Eigen::array<int, Rank> reshaper;
111 reshaper.fill(1);
112 TIn ia, ib;
113
Kevin Chengc42addc2021-09-28 15:41:57 -0700114 ia = this->a->getTensor().broadcast(this->bcast_a);
115 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700116
117 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
118
119 return GraphNode::eval();
120}
121
122// still need to partial specialize this, or Eigen will throw static assertion
123template <DType InDtype, DType OutDtype>
124int BinaryNode<0, InDtype, OutDtype>::eval()
125{
126 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
127
128 return GraphNode::eval();
129}
130
131template <int Rank, DType Dtype>
132int OpAdd<Rank, Dtype>::register_fcn()
133{
134 switch (InDtype)
135 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100137 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
138 int64_t res_in_64 = static_cast<int64_t>(a) + b;
139 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
140 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
141 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
142 return static_cast<InEigenType>(res_in_64);
143 };
144 break;
James Ward8b390432022-08-12 20:48:56 +0100145 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100146 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100147 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100148 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 break;
150 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700151 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700152 }
153
154 return 0;
155}
156
157template <int Rank, DType Dtype>
158int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
159{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800160 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700161 int32_t num_bits = 0;
162 switch (Dtype)
163 {
164 case DType_INT8:
165 num_bits = 8;
166 break;
167 case DType_INT16:
168 num_bits = 16;
169 break;
170 case DType_INT32:
171 num_bits = 32;
172 break;
173 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700174 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700175 }
176
Kevin Chengaee1fac2020-11-11 13:54:06 -0800177 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700178 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
179 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800180
181 InEigenType acc = a >> b;
182
183 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
184 {
185 acc++;
186 }
187
188 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 };
190
191 return 0;
192}
193
194template <int Rank, DType Dtype>
Jerry Gea6827492022-11-16 10:41:55 -0800195OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
196{
197 if (attribute) delete attribute;
198}
199
200template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700201int OpBitwiseAnd<Rank, Dtype>::register_fcn()
202{
203 switch (Dtype)
204 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800205 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 case DType_INT16:
207 case DType_INT32:
208 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
209 break;
210 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700211 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 }
213
214 return 0;
215}
216
217template <int Rank, DType Dtype>
218int OpBitwiseOr<Rank, Dtype>::register_fcn()
219{
220 switch (Dtype)
221 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800222 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 case DType_INT16:
224 case DType_INT32:
225 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
226 break;
227 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700228 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 }
230
231 return 0;
232}
233
234template <int Rank, DType Dtype>
235int OpBitwiseXor<Rank, Dtype>::register_fcn()
236{
237 switch (Dtype)
238 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800239 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700240 case DType_INT16:
241 case DType_INT32:
242 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
243 break;
244 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700245 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700246 }
247
248 return 0;
249}
250
251template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100252int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700253{
254 switch (InDtype)
255 {
256 case DType_INT32:
257 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100258 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700259 int64_t res_in_64 = static_cast<int64_t>(a) / b;
260 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100261 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
262 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700263 return static_cast<InEigenType>(res_in_64);
264 };
265 break;
266 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700267 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700268 }
269
270 return 0;
271}
272
273template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700274int OpLogicalAnd<Rank, Dtype>::register_fcn()
275{
276 switch (Dtype)
277 {
278 case DType_BOOL:
279 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
280 break;
281 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700282 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700283 }
284
285 return 0;
286}
287
288template <int Rank, DType Dtype>
289int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
290{
291 switch (Dtype)
292 {
293 case DType_INT8:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000294 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
295 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
296 (int32_t)b);
297 return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
298 };
Jeremy Johnson66bad802022-01-18 14:48:35 +0000299 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 case DType_INT16:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000301 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
302 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
303 (int32_t)b);
304 return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
305 };
Jeremy Johnson66bad802022-01-18 14:48:35 +0000306 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700307 case DType_INT32:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000308 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
309 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
310 (int32_t)b);
311 return static_cast<OutEigenType>(static_cast<int32_t>(a << b));
312 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700313 break;
314 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700315 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700316 }
317
318 return 0;
319}
320
321template <int Rank, DType Dtype>
322int OpLogicalRightShift<Rank, Dtype>::register_fcn()
323{
Eric Kunzee5e26762020-10-13 16:11:07 -0700324 switch (Dtype)
325 {
326 case DType_INT8:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000327 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
328 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
329 (int32_t)b);
Won Jeoncb7e0292023-06-28 22:34:38 +0000330 return static_cast<OutEigenType>(static_cast<int8_t>(static_cast<uint8_t>(a) >> b));
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000331 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700332 break;
333 case DType_INT16:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000334 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
335 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
336 (int32_t)b);
Won Jeoncb7e0292023-06-28 22:34:38 +0000337 return static_cast<OutEigenType>(static_cast<int16_t>(static_cast<uint16_t>(a) >> b));
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000338 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700339 break;
340 case DType_INT32:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000341 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
342 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
343 (int32_t)b);
Won Jeoncb7e0292023-06-28 22:34:38 +0000344 return static_cast<OutEigenType>(static_cast<int32_t>(static_cast<uint32_t>(a) >> b));
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000345 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700346 break;
347 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700348 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700349 }
350
Eric Kunzee5e26762020-10-13 16:11:07 -0700351 return 0;
352}
353
354template <int Rank, DType Dtype>
355int OpLogicalOr<Rank, Dtype>::register_fcn()
356{
357 switch (Dtype)
358 {
359 case DType_BOOL:
360 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
361 break;
362 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700363 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700364 }
365
366 return 0;
367}
368
369template <int Rank, DType Dtype>
370int OpLogicalXor<Rank, Dtype>::register_fcn()
371{
372 switch (Dtype)
373 {
374 case DType_BOOL:
375 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
376 break;
377 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700378 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700379 }
380
381 return 0;
382}
383
384template <int Rank, DType Dtype>
385int OpMaximum<Rank, Dtype>::register_fcn()
386{
387 switch (Dtype)
388 {
James Ward8b390432022-08-12 20:48:56 +0100389 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100390 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100391 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700392 case DType_INT32:
393 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
394 break;
395 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700396 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700397 }
398
399 return 0;
400}
401
402template <int Rank, DType Dtype>
403int OpMinimum<Rank, Dtype>::register_fcn()
404{
405 switch (Dtype)
406 {
James Ward8b390432022-08-12 20:48:56 +0100407 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100408 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100409 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700410 case DType_INT32:
411 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
412 break;
413 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700414 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700415 }
416
417 return 0;
418}
419
420template <int Rank, DType InDtype, DType OutDtype>
421int OpMul<Rank, InDtype, OutDtype>::register_fcn()
422{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800423 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800424
Eric Kunzee5e26762020-10-13 16:11:07 -0700425 switch (InDtype)
426 {
James Ward8b390432022-08-12 20:48:56 +0100427 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100428 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100429 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100430 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
Kevin Chengaee1fac2020-11-11 13:54:06 -0800431 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700432 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800433 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
434 int64_t result;
435 if (shift > 0)
436 {
437 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700438 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800439 result = result >> shift;
440
Kevin Chengacb550f2021-06-29 15:32:19 -0700441 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
442 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800443 }
444 else
445 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700446 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100447 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
448 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
449 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
450 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800451 }
452
453 return static_cast<OutEigenType>(result);
454 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 break;
456 case DType_INT8:
457 case DType_INT16:
458 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
459 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
460
461 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
462
463 return clamped_output;
464 };
465 break;
466 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700467 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700468 }
469
470 return 0;
471}
472
Jerry Gea6827492022-11-16 10:41:55 -0800473template <int Rank, DType InDtype, DType OutDtype>
474OpMul<Rank, InDtype, OutDtype>::~OpMul()
475{
476 if (attribute) delete attribute;
477}
478
Eric Kunzee5e26762020-10-13 16:11:07 -0700479template <int Rank, DType Dtype>
480int OpPow<Rank, Dtype>::register_fcn()
481{
482 switch (Dtype)
483 {
James Ward8b390432022-08-12 20:48:56 +0100484 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100485 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100486 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100487 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700488 break;
489 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700490 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700491 }
492
493 return 0;
494}
495
496template <int Rank, DType Dtype>
497int OpSub<Rank, Dtype>::register_fcn()
498{
499 switch (InDtype)
500 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700501 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100502 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
503 int64_t res_in_64 = static_cast<int64_t>(a) - b;
504 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
505 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
506 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
507 return static_cast<InEigenType>(res_in_64);
508 };
509 break;
James Ward8b390432022-08-12 20:48:56 +0100510 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100511 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100512 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100513 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700514 break;
515 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700516 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700517 }
518
519 return 0;
520}
521
Kevin Cheng571f7182021-05-24 17:20:01 -0700522template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700523OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
524 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700525 uint64_t id_)
526 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700527{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000528 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700529 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000530
531 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700532}
533
Kevin Cheng571f7182021-05-24 17:20:01 -0700534template <int Rank, DType InDtype>
535OpTable<Rank, InDtype>::~OpTable()
Jerry Gea6827492022-11-16 10:41:55 -0800536{
537 if (attribute) delete attribute;
538}
Eric Kunzee5e26762020-10-13 16:11:07 -0700539
Kevin Cheng571f7182021-05-24 17:20:01 -0700540template <int Rank, DType InDtype>
541int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700542{
543 if (validateRequiredOperands())
544 return 1;
545
546 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
547 {
548 return 1;
549 }
550
Kevin Chengfe392ce2021-10-18 21:51:55 +0000551 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000552 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000553 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
554
555 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700556 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000557 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700558 }
559
Kevin Chengfe392ce2021-10-18 21:51:55 +0000560 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
561 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700562
Kevin Chengfe392ce2021-10-18 21:51:55 +0000563 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700564
565 return 0;
566}
567
Kevin Cheng571f7182021-05-24 17:20:01 -0700568template <int Rank, DType InDtype>
569int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700570{
Kevin Cheng571f7182021-05-24 17:20:01 -0700571 switch (InDtype)
572 {
573 case DType_INT8:
574 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
575 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
576 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000577 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700578
Kevin Cheng571f7182021-05-24 17:20:01 -0700579 return value;
580 });
581 break;
582 case DType_INT16:
583 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
584 // 1. make sure input is int16 range
585 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700586
Kevin Cheng571f7182021-05-24 17:20:01 -0700587 // 2. calculate index and interpolation fraction
588 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
589 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
590 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700591
Jerry Ged511f9e2022-08-12 16:12:40 -0700592 // 3. Add REQUIRE CHECK for extreme large/small slopes
Kevin Chengfe392ce2021-10-18 21:51:55 +0000593 int32_t base = table[index];
594 int32_t next = table[index + 1];
Jerry Ged511f9e2022-08-12 16:12:40 -0700595 int32_t slope = next - base;
596 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range");
597
598 // 4. interpolate, generate 16.7 (23-bit) output
599 int32_t value = (base << 7) + (slope) * frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700600
601 return value;
602 });
603 break;
604 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700605 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700606 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700607
608 return GraphNode::eval();
609}
610
// Explicit template instantiations.
// Each macro line below names one class template together with the element
// type(s) it is instantiated for; judging by the macro names, one invocation
// covers ranks 0 through 6.

// BinaryNodeBase for every (input type, output type) pair used in this file.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BOOL, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL);

// Arithmetic, bitwise, logical and shift operators; the type lists match the
// cases accepted by each operator's register_fcn() above.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);

// OpMul takes separate input and output types (INT8/INT16 inputs produce INT32).
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);

// Instantiation of nodes for comparison operators opEqual, opGreater
// and opGreaterEqual (their output type is BOOL).
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);