blob: bc6353560ef1ce58e774395726204042c70aa470 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
Tai Lya4d748b2023-03-28 22:06:56 +000025template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +000026BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070027 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070028{
29 setRequiredOperands(2, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070030
Kevin Chengc42addc2021-09-28 15:41:57 -070031 a = b = nullptr;
32 result = nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
35}
36
Tai Lya4d748b2023-03-28 22:06:56 +000037template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070038BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
39{}
40
Tai Lya4d748b2023-03-28 22:06:56 +000041template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070042int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
43{
Jerry Gea793f462023-04-11 00:05:02 +000044 // Check Tosa Level
45 auto tosa_level = g_func_config.tosa_level;
46 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
47
Eric Kunzee5e26762020-10-13 16:11:07 -070048 if (validateRequiredOperands())
49 return 1;
50
Kevin Chengc42addc2021-09-28 15:41:57 -070051 // A & B must be the same rank and types
52 if (inputs[0]->matchRankType(*inputs[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -070053 {
54 printNodeValidationError("Binary operator input types must match");
55 return 1;
56 }
57
Kevin Cheng1c3c8472021-11-08 11:19:10 -080058 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
Kevin Cheng478101b2021-10-04 10:43:14 -070059 {
60 std::string err =
Kevin Cheng1c3c8472021-11-08 11:19:10 -080061 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
62 printNodeValidationError(err.c_str());
63 return 1;
64 }
65
66 if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
67 {
68 std::string err =
69 "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
Kevin Cheng478101b2021-10-04 10:43:14 -070070 printNodeValidationError(err.c_str());
71 return 1;
72 }
Eric Kunzee5e26762020-10-13 16:11:07 -070073
Kevin Chengcc61be32021-10-14 17:09:57 -070074 ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
75
Kevin Chengc42addc2021-09-28 15:41:57 -070076 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
77 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
Eric Kunzee5e26762020-10-13 16:11:07 -070078 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
79
Kevin Chengc42addc2021-09-28 15:41:57 -070080 ASSERT_MEM(a && b && result);
Eric Kunzee5e26762020-10-13 16:11:07 -070081
82 return 0;
83}
84
Tai Lya4d748b2023-03-28 22:06:56 +000085template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge135c9552023-05-23 20:59:32 +000086int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast(std::vector<int>& calculated_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -070087{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080088 const std::vector<int>& a_shape = a->getShape();
89 const std::vector<int>& b_shape = b->getShape();
90 const std::vector<int>& output_shape = result->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070091
Jerry Ge135c9552023-05-23 20:59:32 +000092 // calculates the multipliers for Eigen
Kevin Cheng1c3c8472021-11-08 11:19:10 -080093 for (int i = 0; i < Rank; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070094 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080095 bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
96 bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -070097 }
98
Jerry Ge135c9552023-05-23 20:59:32 +000099 // calculates the broadcasted output shape
100 calculated_shape = a_shape;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000101 for (size_t i = 0; i < calculated_shape.size(); i++)
102 {
103 if (calculated_shape[i] == 1)
104 {
Jerry Ge135c9552023-05-23 20:59:32 +0000105 calculated_shape[i] = b_shape[i];
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000106 }
107 else
108 {
109 ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i],
110 "Broadcast_shape failure, input shapes are not compatible");
Jerry Ge135c9552023-05-23 20:59:32 +0000111 }
112 }
113
Eric Kunzee5e26762020-10-13 16:11:07 -0700114 return 0;
115}
116
Tai Lya4d748b2023-03-28 22:06:56 +0000117template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700118int BinaryNode<Rank, InDtype, OutDtype>::eval()
119{
Jerry Ge135c9552023-05-23 20:59:32 +0000120 std::vector<int> calculated_shape;
121 this->broadcast(calculated_shape);
122
123 auto result_shape = this->result->getShape();
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000124 ERROR_IF(calculated_shape != result_shape,
125 "Broadcast_shape failure, calculated_shape and result_shape don't match");
Eric Kunzee5e26762020-10-13 16:11:07 -0700126
127 Eigen::array<int, Rank> reshaper;
128 reshaper.fill(1);
129 TIn ia, ib;
130
Kevin Chengc42addc2021-09-28 15:41:57 -0700131 ia = this->a->getTensor().broadcast(this->bcast_a);
132 ib = this->b->getTensor().broadcast(this->bcast_b);
Eric Kunzee5e26762020-10-13 16:11:07 -0700133
134 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
135
136 return GraphNode::eval();
137}
138
139// still need to partial specialize this, or Eigen will throw static assertion
Tai Lya4d748b2023-03-28 22:06:56 +0000140template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700141int BinaryNode<0, InDtype, OutDtype>::eval()
142{
143 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
144
145 return GraphNode::eval();
146}
147
Tai Lya4d748b2023-03-28 22:06:56 +0000148template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700149int OpAdd<Rank, Dtype>::register_fcn()
150{
151 switch (InDtype)
152 {
Tai Lya4d748b2023-03-28 22:06:56 +0000153 case TOSA_REF_TYPE_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100154 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
155 int64_t res_in_64 = static_cast<int64_t>(a) + b;
156 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
157 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
158 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range");
159 return static_cast<InEigenType>(res_in_64);
160 };
161 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000162 case TOSA_REF_TYPE_FP16:
163 case TOSA_REF_TYPE_BF16:
164 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100165 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000167 case TOSA_REF_TYPE_FP64:
168 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
169 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700170 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000171 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700172 }
173
174 return 0;
175}
176
Tai Lya4d748b2023-03-28 22:06:56 +0000177template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700178int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
179{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800180 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700181 int32_t num_bits = 0;
182 switch (Dtype)
183 {
Tai Lya4d748b2023-03-28 22:06:56 +0000184 case TOSA_REF_TYPE_INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -0700185 num_bits = 8;
186 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000187 case TOSA_REF_TYPE_INT16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 num_bits = 16;
189 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000190 case TOSA_REF_TYPE_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700191 num_bits = 32;
192 break;
193 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000194 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700195 }
196
Kevin Chengaee1fac2020-11-11 13:54:06 -0800197 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
Kevin Chengacb550f2021-06-29 15:32:19 -0700198 REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
199 (int32_t)b, num_bits);
Kevin Chengaee1fac2020-11-11 13:54:06 -0800200
201 InEigenType acc = a >> b;
202
203 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
204 {
205 acc++;
206 }
207
208 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700209 };
210
211 return 0;
212}
213
Tai Lya4d748b2023-03-28 22:06:56 +0000214template <int Rank, TOSA_REF_TYPE Dtype>
Jerry Gea6827492022-11-16 10:41:55 -0800215OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
216{
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000217 if (attribute)
218 delete attribute;
Jerry Gea6827492022-11-16 10:41:55 -0800219}
220
Tai Lya4d748b2023-03-28 22:06:56 +0000221template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700222int OpBitwiseAnd<Rank, Dtype>::register_fcn()
223{
224 switch (Dtype)
225 {
Tai Lya4d748b2023-03-28 22:06:56 +0000226 case TOSA_REF_TYPE_INT8:
227 case TOSA_REF_TYPE_INT16:
228 case TOSA_REF_TYPE_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
230 break;
231 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000232 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700233 }
234
235 return 0;
236}
237
Tai Lya4d748b2023-03-28 22:06:56 +0000238template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700239int OpBitwiseOr<Rank, Dtype>::register_fcn()
240{
241 switch (Dtype)
242 {
Tai Lya4d748b2023-03-28 22:06:56 +0000243 case TOSA_REF_TYPE_INT8:
244 case TOSA_REF_TYPE_INT16:
245 case TOSA_REF_TYPE_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700246 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
247 break;
248 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000249 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700250 }
251
252 return 0;
253}
254
Tai Lya4d748b2023-03-28 22:06:56 +0000255template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700256int OpBitwiseXor<Rank, Dtype>::register_fcn()
257{
258 switch (Dtype)
259 {
Tai Lya4d748b2023-03-28 22:06:56 +0000260 case TOSA_REF_TYPE_INT8:
261 case TOSA_REF_TYPE_INT16:
262 case TOSA_REF_TYPE_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700263 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
264 break;
265 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000266 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700267 }
268
269 return 0;
270}
271
Tai Lya4d748b2023-03-28 22:06:56 +0000272template <int Rank, TOSA_REF_TYPE Dtype>
Matthew Haddon459443c2021-08-23 16:43:13 +0100273int OpIntdiv<Rank, Dtype>::register_fcn()
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700274{
275 switch (InDtype)
276 {
Tai Lya4d748b2023-03-28 22:06:56 +0000277 case TOSA_REF_TYPE_INT32:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700278 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
Matthew Haddon459443c2021-08-23 16:43:13 +0100279 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700280 int64_t res_in_64 = static_cast<int64_t>(a) / b;
281 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
Jeremy Johnson90347472021-09-06 12:04:07 +0100282 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
283 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range");
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700284 return static_cast<InEigenType>(res_in_64);
285 };
286 break;
287 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000288 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700289 }
290
291 return 0;
292}
293
Tai Lya4d748b2023-03-28 22:06:56 +0000294template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700295int OpLogicalAnd<Rank, Dtype>::register_fcn()
296{
297 switch (Dtype)
298 {
Tai Lya4d748b2023-03-28 22:06:56 +0000299 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
301 break;
302 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000303 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 }
305
306 return 0;
307}
308
Tai Lya4d748b2023-03-28 22:06:56 +0000309template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700310int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
311{
312 switch (Dtype)
313 {
Tai Lya4d748b2023-03-28 22:06:56 +0000314 case TOSA_REF_TYPE_INT8:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000315 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
316 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000317 (int32_t)b);
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000318 return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
319 };
Jeremy Johnson66bad802022-01-18 14:48:35 +0000320 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000321 case TOSA_REF_TYPE_INT16:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000322 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
323 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000324 (int32_t)b);
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000325 return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
326 };
Jeremy Johnson66bad802022-01-18 14:48:35 +0000327 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000328 case TOSA_REF_TYPE_INT32:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000329 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
330 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000331 (int32_t)b);
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000332 return static_cast<OutEigenType>(static_cast<int32_t>(a << b));
333 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700334 break;
335 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000336 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700337 }
338
339 return 0;
340}
341
Tai Lya4d748b2023-03-28 22:06:56 +0000342template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700343int OpLogicalRightShift<Rank, Dtype>::register_fcn()
344{
Eric Kunzee5e26762020-10-13 16:11:07 -0700345 switch (Dtype)
346 {
Tai Lya4d748b2023-03-28 22:06:56 +0000347 case TOSA_REF_TYPE_INT8:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000348 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
349 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000350 (int32_t)b);
Won Jeon66704d52023-06-28 22:34:38 +0000351 return static_cast<OutEigenType>(static_cast<int8_t>(static_cast<uint8_t>(a) >> b));
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000352 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700353 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000354 case TOSA_REF_TYPE_INT16:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000355 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
356 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000357 (int32_t)b);
Won Jeon66704d52023-06-28 22:34:38 +0000358 return static_cast<OutEigenType>(static_cast<int16_t>(static_cast<uint16_t>(a) >> b));
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000359 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700360 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000361 case TOSA_REF_TYPE_INT32:
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000362 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
363 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000364 (int32_t)b);
Won Jeon66704d52023-06-28 22:34:38 +0000365 return static_cast<OutEigenType>(static_cast<int32_t>(static_cast<uint32_t>(a) >> b));
Jeremy Johnsoneef86672023-01-18 16:23:20 +0000366 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700367 break;
368 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000369 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700370 }
371
Eric Kunzee5e26762020-10-13 16:11:07 -0700372 return 0;
373}
374
Tai Lya4d748b2023-03-28 22:06:56 +0000375template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700376int OpLogicalOr<Rank, Dtype>::register_fcn()
377{
378 switch (Dtype)
379 {
Tai Lya4d748b2023-03-28 22:06:56 +0000380 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700381 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
382 break;
383 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000384 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700385 }
386
387 return 0;
388}
389
Tai Lya4d748b2023-03-28 22:06:56 +0000390template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700391int OpLogicalXor<Rank, Dtype>::register_fcn()
392{
393 switch (Dtype)
394 {
Tai Lya4d748b2023-03-28 22:06:56 +0000395 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700396 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
397 break;
398 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000399 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700400 }
401
402 return 0;
403}
404
Tai Lya4d748b2023-03-28 22:06:56 +0000405template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700406int OpMaximum<Rank, Dtype>::register_fcn()
407{
408 switch (Dtype)
409 {
Tai Lya4d748b2023-03-28 22:06:56 +0000410 case TOSA_REF_TYPE_FP16:
411 case TOSA_REF_TYPE_BF16:
412 case TOSA_REF_TYPE_FP32:
413 case TOSA_REF_TYPE_FP64:
Jeremy Johnson29b02012024-04-30 13:56:20 +0100414 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
415 if (isnan(a))
416 {
417 return a;
418 }
419 else if (isnan(b))
420 {
421 return b;
422 }
423 else
424 {
425 return a > b ? a : b;
426 }
427 };
428 break;
429
Tai Lya4d748b2023-03-28 22:06:56 +0000430 case TOSA_REF_TYPE_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700431 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
432 break;
433 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000434 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700435 }
436
437 return 0;
438}
439
Tai Lya4d748b2023-03-28 22:06:56 +0000440template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700441int OpMinimum<Rank, Dtype>::register_fcn()
442{
443 switch (Dtype)
444 {
Tai Lya4d748b2023-03-28 22:06:56 +0000445 case TOSA_REF_TYPE_FP16:
446 case TOSA_REF_TYPE_BF16:
447 case TOSA_REF_TYPE_FP32:
448 case TOSA_REF_TYPE_FP64:
Jeremy Johnson29b02012024-04-30 13:56:20 +0100449 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
450 if (isnan(a))
451 {
452 return a;
453 }
454 else if (isnan(b))
455 {
456 return b;
457 }
458 else
459 {
460 return a < b ? a : b;
461 }
462 };
463 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000464 case TOSA_REF_TYPE_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700465 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
466 break;
467 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000468 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700469 }
470
471 return 0;
472}
473
Tai Lya4d748b2023-03-28 22:06:56 +0000474template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
TatWai Chongc7bfa582024-02-12 16:53:23 -0800475int OpMul<Rank, InDtype, OutDtype>::eval()
476{
477 // All cases except in_out_t == int32_t go to the general binary op workflow.
478 if constexpr (InDtype != TOSA_REF_TYPE_INT32)
479 {
480 return BinaryNode<Rank, InDtype, OutDtype>::eval();
481 }
482 else
483 {
484 std::vector<int> calculated_shape;
485 this->broadcast(calculated_shape);
486
487 auto result_shape = this->result->getShape();
488 ERROR_IF(calculated_shape != result_shape,
489 "Broadcast_shape failure, calculated_shape and result_shape don't match");
490
491 TIn ia = this->a->getTensor().broadcast(this->bcast_a);
492 TIn ib = this->b->getTensor().broadcast(this->bcast_b);
493
494 using TInt64 = Eigen::Tensor<int64_t, Rank>;
495 TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn);
496
Jeremy Johnson0a042992024-02-28 13:20:05 +0000497 // Retrieve `shift` value and construct a Eigen tensor instance for it. Shift is stored
498 // as rank-0 tensor in Flatbuffer.
499 auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]);
TatWai Chongc7bfa582024-02-12 16:53:23 -0800500
Jeremy Johnson0a042992024-02-28 13:20:05 +0000501 // Get zero element from rank-0 tensor (i.e. shape = (0,)) in Numpy since `class Tensor`
502 // currenly has no knowledge of the size of rank-0 tensor. Store rank-1 tensor instead
503 // for testing.
504 auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]);
505
506 ASSERT_MEM(s0 || s1);
507
508 int shift = s0 ? s0->getTensor()(0) : s1->getTensor()(0);
TatWai Chongc7bfa582024-02-12 16:53:23 -0800509 TIn is(ia);
510 is.setConstant(shift);
511
512 TOut result = tmp_result.binaryExpr(is, this->shr_fcn);
513 this->result->getTensor() = result;
514
515 return GraphNode::eval();
516 }
517}
518
519// Eigen operators requires tensor operands meet NumDims > 0, partial specialize
520// this like we did for the base class.
521template <>
522int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval()
523{
524 Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn);
525
526 // Retrieve `shift` value.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000527 auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]);
528 auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]);
529 ASSERT_MEM(s0 || s1);
TatWai Chongc7bfa582024-02-12 16:53:23 -0800530
531 Eigen::Tensor<int64_t, 0> shift;
Jeremy Johnson0a042992024-02-28 13:20:05 +0000532 shift.setConstant(s0 ? s0->getTensor()(0) : s1->getTensor()(0));
TatWai Chongc7bfa582024-02-12 16:53:23 -0800533
534 this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn);
535
536 return GraphNode::eval();
537}
538
539template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700540int OpMul<Rank, InDtype, OutDtype>::register_fcn()
541{
TatWai Chongc7bfa582024-02-12 16:53:23 -0800542 // Register evaluation function for in_out_t == int32_t case first as it supports shift
543 // right to int32_t output.
544 if constexpr (InDtype == TOSA_REF_TYPE_INT32)
545 {
546 // Perform multiplication on int32_t inputs to product int64_t result.
547 this->mul_fcn = [](InEigenType a, InEigenType b) -> int64_t {
548 int64_t result = static_cast<int64_t>(a) * static_cast<int64_t>(b);
549 return result;
550 };
551
552 // Convert data from int64_t to int32_t.
553 this->shr_fcn = [this](int64_t a, InEigenType shift) -> OutEigenType {
554 int64_t result;
555 if (shift > 0)
556 {
557 int64_t round = INT64_C(1) << (shift - 1);
558 result = a + round;
559 result = result >> shift;
560
561 REQUIRE(result >= QMin && result <= QMax,
562 "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, QMax);
563 }
564 else
565 {
566 result = a;
567 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
568 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
569 REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
570 return static_cast<InEigenType>(result);
571 }
572 return static_cast<OutEigenType>(result);
573 };
574
575 return 0;
576 }
Kevin Chengaee1fac2020-11-11 13:54:06 -0800577
Eric Kunzee5e26762020-10-13 16:11:07 -0700578 switch (InDtype)
579 {
Tai Lya4d748b2023-03-28 22:06:56 +0000580 case TOSA_REF_TYPE_FP16:
581 case TOSA_REF_TYPE_BF16:
582 case TOSA_REF_TYPE_FP32:
Eric Kunze17fab3b2023-08-02 18:15:20 +0000583 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
Kevin Chengaee1fac2020-11-11 13:54:06 -0800584 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000585 case TOSA_REF_TYPE_FP64:
Eric Kunze17fab3b2023-08-02 18:15:20 +0000586 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
Tai Lya4d748b2023-03-28 22:06:56 +0000587 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000588 case TOSA_REF_TYPE_INT8:
589 case TOSA_REF_TYPE_INT16:
Eric Kunze17fab3b2023-08-02 18:15:20 +0000590 this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType {
Eric Kunzee5e26762020-10-13 16:11:07 -0700591 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
592
593 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
594
595 return clamped_output;
596 };
597 break;
598 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000599 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700600 }
601
602 return 0;
603}
604
Tai Lya4d748b2023-03-28 22:06:56 +0000605template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700606int OpPow<Rank, Dtype>::register_fcn()
607{
608 switch (Dtype)
609 {
Tai Lya4d748b2023-03-28 22:06:56 +0000610 case TOSA_REF_TYPE_FP16:
611 case TOSA_REF_TYPE_BF16:
612 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100613 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700614 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000615 case TOSA_REF_TYPE_FP64:
Jeremy Johnson9a758382023-11-07 16:27:35 +0000616 if (g_func_config.abs_mode)
617 {
Jeremy Johnsoncbbbafa2024-02-06 11:18:47 +0000618 // ABS_ERROR bounds return 2*(1+abs(log(abs(a))*b))
Jeremy Johnson9a758382023-11-07 16:27:35 +0000619 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
620 OutEigenType c = log(a > (InEigenType)0 ? a : (-a)) * b;
Jeremy Johnsoncbbbafa2024-02-06 11:18:47 +0000621 return 2 * (1.0 + (c > (OutEigenType)0 ? c : (-c)));
Jeremy Johnson9a758382023-11-07 16:27:35 +0000622 };
623 }
624 else
625 {
626 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return pow(a, b); };
627 }
Tai Lya4d748b2023-03-28 22:06:56 +0000628 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700629 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000630 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700631 }
632
633 return 0;
634}
635
Tai Lya4d748b2023-03-28 22:06:56 +0000636template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700637int OpSub<Rank, Dtype>::register_fcn()
638{
639 switch (InDtype)
640 {
Tai Lya4d748b2023-03-28 22:06:56 +0000641 case TOSA_REF_TYPE_INT32:
Jeremy Johnson90347472021-09-06 12:04:07 +0100642 this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
643 int64_t res_in_64 = static_cast<int64_t>(a) - b;
644 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
645 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
646 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range");
647 return static_cast<InEigenType>(res_in_64);
648 };
649 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000650 case TOSA_REF_TYPE_FP16:
651 case TOSA_REF_TYPE_BF16:
652 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100653 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
Eric Kunzee5e26762020-10-13 16:11:07 -0700654 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000655 case TOSA_REF_TYPE_FP64:
656 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
657 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700658 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000659 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
Eric Kunzee5e26762020-10-13 16:11:07 -0700660 }
661
662 return 0;
663}
664
Tai Lya4d748b2023-03-28 22:06:56 +0000665template <int Rank, TOSA_REF_TYPE InDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000666OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700667 : GraphNode(sgt_, Op_TABLE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700668{
TatWai Chong51d880e2024-05-12 02:35:04 -0700669 setRequiredOperands(2, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700670 setRequiredRank(0, 6);
671}
672
Tai Lya4d748b2023-03-28 22:06:56 +0000673template <int Rank, TOSA_REF_TYPE InDtype>
Kevin Cheng571f7182021-05-24 17:20:01 -0700674OpTable<Rank, InDtype>::~OpTable()
TatWai Chong51d880e2024-05-12 02:35:04 -0700675{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700676
Tai Lya4d748b2023-03-28 22:06:56 +0000677template <int Rank, TOSA_REF_TYPE InDtype>
Kevin Cheng571f7182021-05-24 17:20:01 -0700678int OpTable<Rank, InDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700679{
Jerry Gea793f462023-04-11 00:05:02 +0000680 // Check Tosa Level
681 auto tosa_level = g_func_config.tosa_level;
682 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
683
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 if (validateRequiredOperands())
685 return 1;
686
687 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
688 {
689 return 1;
690 }
691
Kevin Chengfe392ce2021-10-18 21:51:55 +0000692 ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
TatWai Chong51d880e2024-05-12 02:35:04 -0700693 ERROR_IF(inputs[1]->getDtype() != TableDtype, "OpTable: Unexpected table type");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000694 ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000695
TatWai Chong51d880e2024-05-12 02:35:04 -0700696 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
697 table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
698 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng571f7182021-05-24 17:20:01 -0700699
Kevin Chengfe392ce2021-10-18 21:51:55 +0000700 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700701
702 return 0;
703}
704
Tai Lya4d748b2023-03-28 22:06:56 +0000705template <int Rank, TOSA_REF_TYPE InDtype>
Kevin Cheng571f7182021-05-24 17:20:01 -0700706int OpTable<Rank, InDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700707{
TatWai Chong51d880e2024-05-12 02:35:04 -0700708 ERROR_IF(this->table->getTensor().size() != TableNumEntries, "OpTable: table tensor size must be %u",
709 TableNumEntries);
Kevin Cheng571f7182021-05-24 17:20:01 -0700710 switch (InDtype)
711 {
Tai Lya4d748b2023-03-28 22:06:56 +0000712 case TOSA_REF_TYPE_INT8:
Kevin Cheng571f7182021-05-24 17:20:01 -0700713 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
714 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
715 int32_t index = input_truncated - QInMin;
TatWai Chong51d880e2024-05-12 02:35:04 -0700716 int32_t value = this->table->getTensor()(index);
Eric Kunzee5e26762020-10-13 16:11:07 -0700717
Kevin Cheng571f7182021-05-24 17:20:01 -0700718 return value;
719 });
720 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000721 case TOSA_REF_TYPE_INT16:
Kevin Cheng571f7182021-05-24 17:20:01 -0700722 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
723 // 1. make sure input is int16 range
724 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700725
Kevin Cheng571f7182021-05-24 17:20:01 -0700726 // 2. calculate index and interpolation fraction
727 int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
728 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
729 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
Eric Kunzee5e26762020-10-13 16:11:07 -0700730
Jerry Ged511f9e2022-08-12 16:12:40 -0700731 // 3. Add REQUIRE CHECK for extreme large/small slopes
TatWai Chong51d880e2024-05-12 02:35:04 -0700732 int32_t base = this->table->getTensor()(index);
733 int32_t next = this->table->getTensor()(index + 1);
Jerry Ged511f9e2022-08-12 16:12:40 -0700734 int32_t slope = next - base;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000735 REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(),
736 "OpTable: slope out of int16_t range");
Jerry Ged511f9e2022-08-12 16:12:40 -0700737
738 // 4. interpolate, generate 16.7 (23-bit) output
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000739 int32_t value = (base << 7) + (slope)*frac;
Kevin Cheng571f7182021-05-24 17:20:01 -0700740
741 return value;
742 });
743 break;
744 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000745 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
Kevin Cheng571f7182021-05-24 17:20:01 -0700746 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700747
748 return GraphNode::eval();
749}
750
751// template explicit instantiation
Jared Smolens98c281f2022-12-20 15:09:25 -0800752DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, FP16);
753DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BF16);
754DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, FP32);
755DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT8);
756DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT16);
757DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, INT32);
758DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT32);
759DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT32);
760DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BOOL, BOOL);
761DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL);
762DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL);
763DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL);
764DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000765DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, FP64);
Eric Kunzeedac6ab2023-06-28 13:29:38 -0700766DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, BOOL);
Jared Smolens98c281f2022-12-20 15:09:25 -0800767
James Ward8b390432022-08-12 20:48:56 +0100768DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100769DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100770DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700771DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +0000772DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700773
774DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
775DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
776DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
777
Kevin Cheng3a478572021-01-22 17:21:02 -0800778DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700779DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
780DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
781
Kevin Cheng3a478572021-01-22 17:21:02 -0800782DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700783DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
784DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
785
Kevin Cheng3a478572021-01-22 17:21:02 -0800786DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700787DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
788DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
789
Matthew Haddon459443c2021-08-23 16:43:13 +0100790DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700791
Eric Kunzee5e26762020-10-13 16:11:07 -0700792DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
793
794DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
795DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
796DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
797
798DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
799DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
800DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
801
802DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
803
804DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
805
James Ward8b390432022-08-12 20:48:56 +0100806DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100807DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100808DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700809DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +0000810DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700811
James Ward8b390432022-08-12 20:48:56 +0100812DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100813DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100814DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700815DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +0000816DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700817
James Ward8b390432022-08-12 20:48:56 +0100818DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100819DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100820DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700821DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
822DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
823DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +0000824DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700825
James Ward8b390432022-08-12 20:48:56 +0100826DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100827DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100828DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +0000829DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700830
James Ward8b390432022-08-12 20:48:56 +0100831DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100832DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100833DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700834DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +0000835DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700836
Kevin Cheng571f7182021-05-24 17:20:01 -0700837DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
838DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
Eric Kunzee5e26762020-10-13 16:11:07 -0700839
James Ward8b390432022-08-12 20:48:56 +0100840// Instantiation of nodes for comparison operators opEqual, opGreater
841// and opGreaterEqual
842DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
James Ward24dbc422022-10-19 12:20:31 +0100843DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100844DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700845DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000846DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP64, BOOL);