blob: 6aa0c0f62d4ad24d4516eb53a3be0cfef8857133 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jeremy Johnsoneef86672023-01-18 16:23:20 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
27 const Op& op_,
Kevin Chengacb550f2021-06-29 15:32:19 -070028 uint64_t id_)
29 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(2, 1);
32 setRequiredRank(0, 6);
33
Kevin Chengc42addc2021-09-28 15:41:57 -070034 a = b = nullptr;
35 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070036
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
Jerry Gea793f462023-04-11 00:05:02 +000047 // Check Tosa Level
48 auto tosa_level = g_func_config.tosa_level;
49 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
50
Eric Kunzee5e26762020-10-13 16:11:07 -070051 if (validateRequiredOperands())
52 return 1;
53
54 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
55 {
56 return 1;
57 }
58
Kevin Chengc42addc2021-09-28 15:41:57 -070059 // A & B must be the same rank and types
60 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070061 {
62 printNodeValidationError("Binary operator input types must match");
63 return 1;
64 }
65
Kevin Cheng1c3c8472021-11-08 11:19:10 -080066 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070067 {
68 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080069 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
70 printNodeValidationError(err.c_str());
71 return 1;
72 }
73
74 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
75 {
76 std::string err =
77 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070078 printNodeValidationError(err.c_str());
79 return 1;
80 }
Eric Kunzee5e26762020-10-13 16:11:07 -070081
Kevin Chengcc61be32021-10-14 17:09:57 -070082 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
83
Kevin Chengc42addc2021-09-28 15:41:57 -070084 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
85 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070086 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
87
Kevin Chengc42addc2021-09-28 15:41:57 -070088 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070089
90 return 0;
91}
92
93template <int Rank, DType InDtype, DType OutDtype>
94int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
95{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080096 const std::vector<int>& a_shape = a->getShape();
97 const std::vector<int>& b_shape = b->getShape();
98 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070099
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800100 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700101 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800102 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
103 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700104 }
105
106 return 0;
107}
108
109template <int Rank, DType InDtype, DType OutDtype>
110int BinaryNode<Rank, InDtype, OutDtype>::eval()
111{
112 this->broadcast();
113
114 Eigen::array<int, Rank> reshaper;
115 reshaper.fill(1);
116 TIn ia, ib;
117
Kevin Chengc42addc2021-09-28 15:41:57 -0700118 ia = this->a->getTensor().broadcast(this->bcast_a);
119 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700120
121 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
122
123 return GraphNode::eval();
124}
125
126// still need to partial specialize this, or Eigen will throw static assertion
127template <DType InDtype, DType OutDtype>
128int BinaryNode<0, InDtype, OutDtype>::eval()
129{
130 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
131
132 return GraphNode::eval();
133}
134
135template <int Rank, DType Dtype>
136int OpAdd<Rank, Dtype>::register_fcn()
137{
138 switch (InDtype)
139 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700140 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100141 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
142 int64_t res_in_64 = static_cast<int64_t>(a) + b;
143 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
144 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
145 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
146 return static_cast<InEigenType>(res_in_64);
147 };
148 break;
James Ward8b390432022-08-12 20:48:56 +0100149 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100150 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100151 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100152 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700153 break;
154 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700155 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 }
157
158 return 0;
159}
160
161template <int Rank, DType Dtype>
162int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
163{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800164 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700165 int32_t num_bits = 0;
166 switch (Dtype)
167 {
168 case DType_INT8:
169 num_bits = 8;
170 break;
171 case DType_INT16:
172 num_bits = 16;
173 break;
174 case DType_INT32:
175 num_bits = 32;
176 break;
177 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700178 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 }
180
Kevin Chengaee1fac2020-11-11 13:54:06 -0800181 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700182 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
183 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800184
185 InEigenType acc = a >> b;
186
187 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
188 {
189 acc++;
190 }
191
192 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 };
194
195 return 0;
196}
197
198template <int Rank, DType Dtype>
Jerry Gea6827492022-11-16 10:41:55 -0800199OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
200{
201 if (attribute) delete attribute;
202}
203
204template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700205int OpBitwiseAnd<Rank, Dtype>::register_fcn()
206{
207 switch (Dtype)
208 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800209 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700210 case DType_INT16:
211 case DType_INT32:
212 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
213 break;
214 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700215 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 }
217
218 return 0;
219}
220
221template <int Rank, DType Dtype>
222int OpBitwiseOr<Rank, Dtype>::register_fcn()
223{
224 switch (Dtype)
225 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800226 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700227 case DType_INT16:
228 case DType_INT32:
229 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
230 break;
231 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700232 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700233 }
234
235 return 0;
236}
237
238template <int Rank, DType Dtype>
239int OpBitwiseXor<Rank, Dtype>::register_fcn()
240{
241 switch (Dtype)
242 {
Kevin Cheng3a478572021-01-22 17:21:02 -0800243 case DType_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 case DType_INT16:
245 case DType_INT32:
246 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
247 break;
248 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700249 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700250 }
251
252 return 0;
253}
254
255template <int Rank, DType Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100256int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700257{
258 switch (InDtype)
259 {
260 case DType_INT32:
261 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100262 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700263 int64_t res_in_64 = static_cast<int64_t>(a) / b;
264 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100265 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
266 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700267 return static_cast<InEigenType>(res_in_64);
268 };
269 break;
270 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700271 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700272 }
273
274 return 0;
275}
276
277template <int Rank, DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700278int OpLogicalAnd<Rank, Dtype>::register_fcn()
279{
280 switch (Dtype)
281 {
282 case DType_BOOL:
283 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
284 break;
285 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700286 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 }
288
289 return 0;
290}
291
292template <int Rank, DType Dtype>
293int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
294{
295 switch (Dtype)
296 {
297 case DType_INT8:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000298 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
299 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
300 (int32_t)b);
301 return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
302 };
Jeremy Johnson66bad802022-01-18 14:48:35 +0000303 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 case DType_INT16:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000305 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
306 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
307 (int32_t)b);
308 return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
309 };
Jeremy Johnson66bad802022-01-18 14:48:35 +0000310 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700311 case DType_INT32:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000312 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
313 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
314 (int32_t)b);
315 return static_cast<OutEigenType>(static_cast<int32_t>(a << b));
316 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700317 break;
318 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700319 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700320 }
321
322 return 0;
323}
324
325template <int Rank, DType Dtype>
326int OpLogicalRightShift<Rank, Dtype>::register_fcn()
327{
Eric Kunzee5e26762020-10-13 16:11:07 -0700328 switch (Dtype)
329 {
330 case DType_INT8:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000331 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
332 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
333 (int32_t)b);
334 return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b);
335 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700336 break;
337 case DType_INT16:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000338 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
339 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
340 (int32_t)b);
341 return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b);
342 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700343 break;
344 case DType_INT32:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000345 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
346 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
347 (int32_t)b);
348 return static_cast<OutEigenType>(static_cast<int32_t>(a) >> b);
349 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700350 break;
351 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700352 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700353 }
354
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 return 0;
356}
357
358template <int Rank, DType Dtype>
359int OpLogicalOr<Rank, Dtype>::register_fcn()
360{
361 switch (Dtype)
362 {
363 case DType_BOOL:
364 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
365 break;
366 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700367 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700368 }
369
370 return 0;
371}
372
373template <int Rank, DType Dtype>
374int OpLogicalXor<Rank, Dtype>::register_fcn()
375{
376 switch (Dtype)
377 {
378 case DType_BOOL:
379 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
380 break;
381 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700382 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700383 }
384
385 return 0;
386}
387
388template <int Rank, DType Dtype>
389int OpMaximum<Rank, Dtype>::register_fcn()
390{
391 switch (Dtype)
392 {
James Ward8b390432022-08-12 20:48:56 +0100393 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100394 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100395 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700396 case DType_INT32:
397 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
398 break;
399 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700400 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 }
402
403 return 0;
404}
405
406template <int Rank, DType Dtype>
407int OpMinimum<Rank, Dtype>::register_fcn()
408{
409 switch (Dtype)
410 {
James Ward8b390432022-08-12 20:48:56 +0100411 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100412 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100413 case DType_FP32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700414 case DType_INT32:
415 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
416 break;
417 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700418 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700419 }
420
421 return 0;
422}
423
424template <int Rank, DType InDtype, DType OutDtype>
425int OpMul<Rank, InDtype, OutDtype>::register_fcn()
426{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800427 int32_t shift = attribute->shift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800428
Eric Kunzee5e26762020-10-13 16:11:07 -0700429 switch (InDtype)
430 {
James Ward8b390432022-08-12 20:48:56 +0100431 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100432 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100433 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100434 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
Kevin Chengaee1fac2020-11-11 13:54:06 -0800435 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700436 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800437 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
438 int64_t result;
439 if (shift > 0)
440 {
Jerry Gecf305db2023-03-06 13:07:36 -0800441 int64_t round = INT64_C(1) << (shift - 1);
Kevin Cheng9257fd52021-04-14 15:55:31 -0700442 result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800443 result = result >> shift;
444
Kevin Chengacb550f2021-06-29 15:32:19 -0700445 REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
446 result, QMin, QMax);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800447 }
448 else
449 {
Kevin Chengc42addc2021-09-28 15:41:57 -0700450 result = static_cast<int64_t>(a) * b;
Jeremy Johnson90347472021-09-06 12:04:07 +0100451 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
452 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
453 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
454 return static_cast<InEigenType>(result);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800455 }
456
457 return static_cast<OutEigenType>(result);
458 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700459 break;
460 case DType_INT8:
461 case DType_INT16:
462 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
463 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
464
465 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
466
467 return clamped_output;
468 };
469 break;
470 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700471 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700472 }
473
474 return 0;
475}
476
Jerry Gea6827492022-11-16 10:41:55 -0800477template <int Rank, DType InDtype, DType OutDtype>
478OpMul<Rank, InDtype, OutDtype>::~OpMul()
479{
480 if (attribute) delete attribute;
481}
482
Eric Kunzee5e26762020-10-13 16:11:07 -0700483template <int Rank, DType Dtype>
484int OpPow<Rank, Dtype>::register_fcn()
485{
486 switch (Dtype)
487 {
James Ward8b390432022-08-12 20:48:56 +0100488 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100489 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100490 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100491 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700492 break;
493 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700494 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700495 }
496
497 return 0;
498}
499
500template <int Rank, DType Dtype>
501int OpSub<Rank, Dtype>::register_fcn()
502{
503 switch (InDtype)
504 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700505 case DType_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100506 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
507 int64_t res_in_64 = static_cast<int64_t>(a) - b;
508 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
509 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
510 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
511 return static_cast<InEigenType>(res_in_64);
512 };
513 break;
James Ward8b390432022-08-12 20:48:56 +0100514 case DType_FP16:
James Ward24dbc422022-10-19 12:20:31 +0100515 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100516 case DType_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100517 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700518 break;
519 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700520 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700521 }
522
523 return 0;
524}
525
Kevin Cheng571f7182021-05-24 17:20:01 -0700526template <int Rank, DType InDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700527OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
528 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700529 uint64_t id_)
530 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700531{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000532 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700533 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000534
535 INIT_ATTRIBUTE(Table);
Eric Kunzee5e26762020-10-13 16:11:07 -0700536}
537
Kevin Cheng571f7182021-05-24 17:20:01 -0700538template <int Rank, DType InDtype>
539OpTable<Rank, InDtype>::~OpTable()
Jerry Gea6827492022-11-16 10:41:55 -0800540{
541 if (attribute) delete attribute;
542}
Eric Kunzee5e26762020-10-13 16:11:07 -0700543
Kevin Cheng571f7182021-05-24 17:20:01 -0700544template <int Rank, DType InDtype>
545int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700546{
Jerry Gea793f462023-04-11 00:05:02 +0000547 // Check Tosa Level
548 auto tosa_level = g_func_config.tosa_level;
549 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
550
Eric Kunzee5e26762020-10-13 16:11:07 -0700551 if (validateRequiredOperands())
552 return 1;
553
554 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
555 {
556 return 1;
557 }
558
Kevin Chengfe392ce2021-10-18 21:51:55 +0000559 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000560 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000561 ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
562
563 for (uint32_t i = 0; i < TableNumEntries; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700564 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000565 table[i] = (TableEigenType)attribute->table()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700566 }
567
Kevin Chengfe392ce2021-10-18 21:51:55 +0000568 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
569 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700570
Kevin Chengfe392ce2021-10-18 21:51:55 +0000571 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700572
573 return 0;
574}
575
Kevin Cheng571f7182021-05-24 17:20:01 -0700576template <int Rank, DType InDtype>
577int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700578{
Kevin Cheng571f7182021-05-24 17:20:01 -0700579 switch (InDtype)
580 {
581 case DType_INT8:
582 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
583 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
584 int32_t index = input_truncated - QInMin;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000585 int32_t value = table[index];
Eric Kunzee5e26762020-10-13 16:11:07 -0700586
Kevin Cheng571f7182021-05-24 17:20:01 -0700587 return value;
588 });
589 break;
590 case DType_INT16:
591 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
592 // 1. make sure input is int16 range
593 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700594
Kevin Cheng571f7182021-05-24 17:20:01 -0700595 // 2. calculate index and interpolation fraction
596 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
597 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
598 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700599
Jerry Ged511f9e2022-08-12 16:12:40 -0700600 // 3. Add REQUIRE CHECK for extreme large/small slopes
Kevin Chengfe392ce2021-10-18 21:51:55 +0000601 int32_t base = table[index];
602 int32_t next = table[index + 1];
Jerry Ged511f9e2022-08-12 16:12:40 -0700603 int32_t slope = next - base;
604 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range");
605
606 // 4. interpolate, generate 16.7 (23-bit) output
607 int32_t value = (base << 7) + (slope) * frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700608
609 return value;
610 });
611 break;
612 default:
Kevin Chengacb550f2021-06-29 15:32:19 -0700613 ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700614 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700615
616 return GraphNode::eval();
617}
618
619// template explicit instantiation
Jared Smolens98c281f2022-12-20 15:09:25 -0800620DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, FP16);
621DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BF16);
622DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, FP32);
623DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT8);
624DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT16);
625DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, INT32);
626DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT32);
627DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT32);
628DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BOOL, BOOL);
629DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL);
630DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL);
631DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL);
632DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL);
633
James Ward8b390432022-08-12 20:48:56 +0100634DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100635DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100636DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700637DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
638
639DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
640DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
641DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
642
Kevin Cheng3a478572021-01-22 17:21:02 -0800643DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700644DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
645DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
646
Kevin Cheng3a478572021-01-22 17:21:02 -0800647DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700648DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
649DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
650
Kevin Cheng3a478572021-01-22 17:21:02 -0800651DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700652DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
653DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
654
Matthew Haddon459443c2021-08-23 16:43:13 +0100655DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700656
Eric Kunzee5e26762020-10-13 16:11:07 -0700657DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
658
659DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
660DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
661DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
662
663DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
664DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
665DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
666
667DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
668
669DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
670
James Ward8b390432022-08-12 20:48:56 +0100671DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100672DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100673DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700674DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
675
James Ward8b390432022-08-12 20:48:56 +0100676DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100677DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100678DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700679DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
680
James Ward8b390432022-08-12 20:48:56 +0100681DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100682DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100683DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700684DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
685DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
686DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
687
James Ward8b390432022-08-12 20:48:56 +0100688DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100689DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100690DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700691
James Ward8b390432022-08-12 20:48:56 +0100692DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100693DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100694DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700695DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
696
Kevin Cheng571f7182021-05-24 17:20:01 -0700697DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
698DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700699
James Ward8b390432022-08-12 20:48:56 +0100700// Instantiation of nodes for comparison operators opEqual, opGreater
701// and opGreaterEqual
702DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
James Ward24dbc422022-10-19 12:20:31 +0100703DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100704DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700705DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);