
// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
Kevin Chengacb550f2021-06-29 15:32:19 -070028 uint64_t id_)
29 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(2, 1);
32 setRequiredRank(0, 6);
33
Kevin Chengc42addc2021-09-28 15:41:57 -070034 a = b = nullptr;
35 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070036
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
Kevin Chengc42addc2021-09-28 15:41:57 -070055 // A & B must be the same rank and types
56 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 printNodeValidationError("Binary operator input types must match");
59 return 1;
60 }
61
Kevin Cheng1c3c8472021-11-08 11:19:10 -080062 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070063 {
64 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080065 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
66 printNodeValidationError(err.c_str());
67 return 1;
68 }
69
70 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
71 {
72 std::string err =
73 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070074 printNodeValidationError(err.c_str());
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
Kevin Chengcc61be32021-10-14 17:09:57 -070078 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
79
Kevin Chengc42addc2021-09-28 15:41:57 -070080 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
81 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070082 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
83
Kevin Chengc42addc2021-09-28 15:41:57 -070084 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070085
86 return 0;
87}
88
89template <int Rank, DType InDtype, DType OutDtype>
90int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
91{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080092 const std::vector<int>& a_shape = a->getShape();
93 const std::vector<int>& b_shape = b->getShape();
94 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070095
Kevin Cheng1c3c8472021-11-08 11:19:10 -080096 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070097 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080098 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
99 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 }
101
102 return 0;
103}
104
105template <int Rank, DType InDtype, DType OutDtype>
106int BinaryNode<Rank, InDtype, OutDtype>::eval()
107{
108 this->broadcast();
109
110 Eigen::array<int, Rank> reshaper;
111 reshaper.fill(1);
112 TIn ia, ib;
113
Kevin Chengc42addc2021-09-28 15:41:57 -0700114 ia = this->a->getTensor().broadcast(this->bcast_a);
115 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700116
117 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
118
119 return GraphNode::eval();
120}
121
122// still need to partial specialize this, or Eigen will throw static assertion
123template <DType InDtype, DType OutDtype>
124int BinaryNode<0, InDtype, OutDtype>::eval()
125{
126 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
127
128 return GraphNode::eval();
129}
130
131template <int Rank, DType Dtype>
132int OpAdd<Rank, Dtype>::register_fcn()
133{
134 switch (InDtype)
135 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100137 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
138 int64_t res_in_64 = static_cast<int64_t>(a) + b;
139 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
140 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
141 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
142 return static_cast<InEigenType>(res_in_64);
143 };
144 break;
James Ward8b390432022-08-12 20:48:56 +0100145 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100146 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100147 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100148 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 break;
150 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700151 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700152 }
153
154 return 0;
155}
156
157template <int Rank, DType Dtype>
158int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
159{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800160 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700161 int32_t num_bits = 0;
162 switch (Dtype)
163 {
164 case DType_INT8:
165 num_bits = 8;
166 break;
167 case DType_INT16:
168 num_bits = 16;
169 break;
170 case DType_INT32:
171 num_bits = 32;
172 break;
173 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700174 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700175 }
176
Kevin Chengaee1fac2020-11-11 13:54:06 -0800177 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700178 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
179 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800180
181 InEigenType acc = a >> b;
182
183 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
184 {
185 acc++;
186 }
187
188 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 };
190
191 return 0;
192}
193
194template <int Rank, DType Dtype>
Jerry Gea6827492022-11-16 10:41:55 -0800195OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
196{
197 if (attribute) delete attribute;
198}
199
200template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700201int OpBitwiseAnd<Rank, Dtype>::register_fcn()
202{
203 switch (Dtype)
204 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800205 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 case DType_INT16:
207 case DType_INT32:
208 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
209 break;
210 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700211 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 }
213
214 return 0;
215}
216
217template <int Rank, DType Dtype>
218int OpBitwiseOr<Rank, Dtype>::register_fcn()
219{
220 switch (Dtype)
221 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800222 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 case DType_INT16:
224 case DType_INT32:
225 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
226 break;
227 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700228 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 }
230
231 return 0;
232}
233
234template <int Rank, DType Dtype>
235int OpBitwiseXor<Rank, Dtype>::register_fcn()
236{
237 switch (Dtype)
238 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800239 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700240 case DType_INT16:
241 case DType_INT32:
242 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
243 break;
244 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700245 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700246 }
247
248 return 0;
249}
250
251template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100252int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700253{
254 switch (InDtype)
255 {
256 case DType_INT32:
257 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100258 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700259 int64_t res_in_64 = static_cast<int64_t>(a) / b;
260 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100261 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
262 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700263 return static_cast<InEigenType>(res_in_64);
264 };
265 break;
266 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700267 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700268 }
269
270 return 0;
271}
272
273template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700274int OpLogicalAnd<Rank, Dtype>::register_fcn()
275{
276 switch (Dtype)
277 {
278 case DType_BOOL:
279 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
280 break;
281 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700282 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700283 }
284
285 return 0;
286}
287
288template <int Rank, DType Dtype>
289int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
290{
Jeremy Johnson66bad802022-01-18 14:48:35 +0000291 int32_t num_bits = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 switch (Dtype)
293 {
294 case DType_INT8:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000295 num_bits = 8;
296 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700297 case DType_INT16:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000298 num_bits = 16;
299 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 case DType_INT32:
Jeremy Johnson66bad802022-01-18 14:48:35 +0000301 num_bits = 32;
Eric Kunzee5e26762020-10-13 16:11:07 -0700302 break;
303 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700304 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700305 }
Jeremy Johnson66bad802022-01-18 14:48:35 +0000306 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
307 uint32_t mask = ONES_MASK(num_bits);
308 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
309 (int32_t)b);
310 return (a << b) & mask;
311 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700312
313 return 0;
314}
315
316template <int Rank, DType Dtype>
317int OpLogicalRightShift<Rank, Dtype>::register_fcn()
318{
319 int32_t num_bits = 0;
320 switch (Dtype)
321 {
322 case DType_INT8:
323 num_bits = 8;
324 break;
325 case DType_INT16:
326 num_bits = 16;
327 break;
328 case DType_INT32:
329 num_bits = 32;
330 break;
331 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700332 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700333 }
334
Jeremy Johnson66bad802022-01-18 14:48:35 +0000335 this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Eric Kunzee5e26762020-10-13 16:11:07 -0700336 uint32_t mask = ONES_MASK(num_bits) >> b;
Jeremy Johnson66bad802022-01-18 14:48:35 +0000337 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
338 (int32_t)b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700339 return (a >> b) & mask;
340 };
341
342 return 0;
343}
344
345template <int Rank, DType Dtype>
346int OpLogicalOr<Rank, Dtype>::register_fcn()
347{
348 switch (Dtype)
349 {
350 case DType_BOOL:
351 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
352 break;
353 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700354 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 }
356
357 return 0;
358}
359
360template <int Rank, DType Dtype>
361int OpLogicalXor<Rank, Dtype>::register_fcn()
362{
363 switch (Dtype)
364 {
365 case DType_BOOL:
366 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
367 break;
368 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700369 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700370 }
371
372 return 0;
373}
374
375template <int Rank, DType Dtype>
376int OpMaximum<Rank, Dtype>::register_fcn()
377{
378 switch (Dtype)
379 {
James Ward8b390432022-08-12 20:48:56 +0100380 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100381 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100382 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700383 case DType_INT32:
384 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
385 break;
386 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700387 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700388 }
389
390 return 0;
391}
392
393template <int Rank, DType Dtype>
394int OpMinimum<Rank, Dtype>::register_fcn()
395{
396 switch (Dtype)
397 {
James Ward8b390432022-08-12 20:48:56 +0100398 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100399 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100400 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 case DType_INT32:
402 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
403 break;
404 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700405 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700406 }
407
408 return 0;
409}
410
411template <int Rank, DType InDtype, DType OutDtype>
412int OpMul<Rank, InDtype, OutDtype>::register_fcn()
413{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800414 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800415
Eric Kunzee5e26762020-10-13 16:11:07 -0700416 switch (InDtype)
417 {
James Ward8b390432022-08-12 20:48:56 +0100418 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100419 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100420 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100421 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
Kevin Chengaee1fac2020-11-11 13:54:06 -0800422 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700423 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800424 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
425 int64_t result;
426 if (shift > 0)
427 {
428 int64_t round = 1L << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700429 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800430 result = result >> shift;
431
Kevin Chengacb550f2021-06-29 15:32:19 -0700432 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
433 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800434 }
435 else
436 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700437 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100438 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
439 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
440 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
441 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800442 }
443
444 return static_cast<OutEigenType>(result);
445 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700446 break;
447 case DType_INT8:
448 case DType_INT16:
449 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
450 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
451
452 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
453
454 return clamped_output;
455 };
456 break;
457 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700458 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700459 }
460
461 return 0;
462}
463
Jerry Gea6827492022-11-16 10:41:55 -0800464template <int Rank, DType InDtype, DType OutDtype>
465OpMul<Rank, InDtype, OutDtype>::~OpMul()
466{
467 if (attribute) delete attribute;
468}
469
Eric Kunzee5e26762020-10-13 16:11:07 -0700470template <int Rank, DType Dtype>
471int OpPow<Rank, Dtype>::register_fcn()
472{
473 switch (Dtype)
474 {
James Ward8b390432022-08-12 20:48:56 +0100475 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100476 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100477 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100478 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 break;
480 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700481 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700482 }
483
484 return 0;
485}
486
487template <int Rank, DType Dtype>
488int OpSub<Rank, Dtype>::register_fcn()
489{
490 switch (InDtype)
491 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700492 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100493 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
494 int64_t res_in_64 = static_cast<int64_t>(a) - b;
495 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
496 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
497 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
498 return static_cast<InEigenType>(res_in_64);
499 };
500 break;
James Ward8b390432022-08-12 20:48:56 +0100501 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100502 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100503 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100504 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700505 break;
506 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700507 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700508 }
509
510 return 0;
511}
512
Kevin Cheng571f7182021-05-24 17:20:01 -0700513template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700514OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
515 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700516 uint64_t id_)
517 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700518{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000519 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700520 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000521
522 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700523}
524
Kevin Cheng571f7182021-05-24 17:20:01 -0700525template <int Rank, DType InDtype>
526OpTable<Rank, InDtype>::~OpTable()
Jerry Gea6827492022-11-16 10:41:55 -0800527{
528 if (attribute) delete attribute;
529}
Eric Kunzee5e26762020-10-13 16:11:07 -0700530
Kevin Cheng571f7182021-05-24 17:20:01 -0700531template <int Rank, DType InDtype>
532int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700533{
534 if (validateRequiredOperands())
535 return 1;
536
537 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
538 {
539 return 1;
540 }
541
Kevin Chengfe392ce2021-10-18 21:51:55 +0000542 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000543 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000544 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
545
546 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700547 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000548 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700549 }
550
Kevin Chengfe392ce2021-10-18 21:51:55 +0000551 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
552 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700553
Kevin Chengfe392ce2021-10-18 21:51:55 +0000554 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700555
556 return 0;
557}
558
Kevin Cheng571f7182021-05-24 17:20:01 -0700559template <int Rank, DType InDtype>
560int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700561{
Kevin Cheng571f7182021-05-24 17:20:01 -0700562 switch (InDtype)
563 {
564 case DType_INT8:
565 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
566 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
567 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000568 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
Kevin Cheng571f7182021-05-24 17:20:01 -0700570 return value;
571 });
572 break;
573 case DType_INT16:
574 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
575 // 1. make sure input is int16 range
576 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700577
Kevin Cheng571f7182021-05-24 17:20:01 -0700578 // 2. calculate index and interpolation fraction
579 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
580 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
581 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700582
Jerry Ged511f9e2022-08-12 16:12:40 -0700583 // 3. Add REQUIRE CHECK for extreme large/small slopes
Kevin Chengfe392ce2021-10-18 21:51:55 +0000584 int32_t base = table[index];
585 int32_t next = table[index + 1];
Jerry Ged511f9e2022-08-12 16:12:40 -0700586 int32_t slope = next - base;
587 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range");
588
589 // 4. interpolate, generate 16.7 (23-bit) output
590 int32_t value = (base << 7) + (slope) * frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700591
592 return value;
593 });
594 break;
595 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700596 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700597 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700598
599 return GraphNode::eval();
600}
601
602// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +0100603DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100604DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100605DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700606DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
607
608DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
609DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
611
Kevin Cheng3a478572021-01-22 17:21:02 -0800612DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700613DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
614DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
615
Kevin Cheng3a478572021-01-22 17:21:02 -0800616DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700617DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
619
Kevin Cheng3a478572021-01-22 17:21:02 -0800620DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700621DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
622DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
623
Matthew Haddon459443c2021-08-23 16:43:13 +0100624DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700625
Eric Kunzee5e26762020-10-13 16:11:07 -0700626DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
627
628DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
629DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
630DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
631
632DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
633DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
634DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
635
636DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
637
638DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
639
James Ward8b390432022-08-12 20:48:56 +0100640DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100641DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100642DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700643DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
644
James Ward8b390432022-08-12 20:48:56 +0100645DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100646DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100647DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700648DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
649
James Ward8b390432022-08-12 20:48:56 +0100650DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100651DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100652DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700653DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
654DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
655DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
656
James Ward8b390432022-08-12 20:48:56 +0100657DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100658DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100659DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700660
James Ward8b390432022-08-12 20:48:56 +0100661DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100662DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100663DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700664DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
665
Kevin Cheng571f7182021-05-24 17:20:01 -0700666DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
667DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700668
James Ward8b390432022-08-12 20:48:56 +0100669// Instantiation of nodes for comparison operators opEqual, opGreater
670// and opGreaterEqual
671DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
James Ward24dbc422022-10-19 12:20:31 +0100672DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100673DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700674DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);