blob: d07790ee358b10a324c09b95e18eeea941a2375d [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_binary.h"
17#include "arith_util.h"
18#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
26BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
27 : GraphNode(op_, id_)
28{
29 setRequiredOperands(2, 1);
30 setRequiredRank(0, 6);
31
32 a_rank = b_rank = max_input_rank = -1;
33 a = b = nullptr;
34 a_rank0 = b_rank0 = nullptr;
35 result = nullptr;
36
37 fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
38}
39
40template <int Rank, DType InDtype, DType OutDtype>
41BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
42{}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
55 a_rank = inputs[0]->getRank();
56 b_rank = inputs[1]->getRank();
57 if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
58 {
59 printNodeValidationError("Binary operator input ranks must match");
60 return 1;
61 }
62
63 max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
64
65 // A & B must be the same types
66 if (inputs[0]->matchType(*inputs[1]))
67 {
68 printNodeValidationError("Binary operator input types must match");
69 return 1;
70 }
71
72 // Result's geometry must match, but the type may be wider
73 if (outputs[0]->getRank() != max_input_rank)
74 {
75 printNodeValidationError("Binary operator input and output genometry must match");
76 return 1;
77 }
78
79 if (a_rank == max_input_rank)
80 {
81 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
82 }
83 else
84 {
85 a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
86 }
87
88 if (b_rank == max_input_rank)
89 {
90 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
91 }
92 else
93 {
94 b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
95 }
96
97 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
98
99 // either a or b can be rank0
100 // a_rank0 and b_rank0 can't be valid at the same time.
101 // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0'
102 ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
103
104 return 0;
105}
106
107template <int Rank, DType InDtype, DType OutDtype>
108int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
109{
110 auto output_shape = result->getTensor().dimensions();
111
112 std::vector<int> a_shape, b_shape;
113
114 if (a_rank == max_input_rank)
115 {
116 a_shape = a->getShape();
117 }
118 else
119 {
120 a_shape.assign(max_input_rank, 1);
121 }
122
123 if (b_rank == max_input_rank)
124 {
125 b_shape = b->getShape();
126 }
127 else
128 {
129 b_shape.assign(max_input_rank, 1);
130 }
131
132 for (int i = 0; i < max_input_rank; i++)
133 {
134 if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
135 {
136 bcast_a[i] = output_shape[i];
137 }
138 else
139 {
140 bcast_a[i] = 1;
141 }
142 if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
143 {
144 bcast_b[i] = output_shape[i];
145 }
146 else
147 {
148 bcast_b[i] = 1;
149 }
150 }
151
152 return 0;
153}
154
155template <int Rank, DType InDtype, DType OutDtype>
156int BinaryNode<Rank, InDtype, OutDtype>::eval()
157{
158 this->broadcast();
159
160 Eigen::array<int, Rank> reshaper;
161 reshaper.fill(1);
162 TIn ia, ib;
163
164 if (this->a_rank == this->max_input_rank)
165 {
166 ia = this->a->getTensor().broadcast(this->bcast_a);
167 }
168 else
169 {
170 ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
171 }
172
173 if (this->b_rank == this->max_input_rank)
174 {
175 ib = this->b->getTensor().broadcast(this->bcast_b);
176 }
177 else
178 {
179 ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
180 }
181
182 this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
183
184 return GraphNode::eval();
185}
186
187// still need to partial specialize this, or Eigen will throw static assertion
188template <DType InDtype, DType OutDtype>
189int BinaryNode<0, InDtype, OutDtype>::eval()
190{
191 this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
192
193 return GraphNode::eval();
194}
195
196template <int Rank, DType Dtype>
197int OpAdd<Rank, Dtype>::register_fcn()
198{
199 switch (InDtype)
200 {
201 case DType_FLOAT:
202 case DType_INT32:
203 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
204 break;
205 default:
206 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
207 }
208
209 return 0;
210}
211
212template <int Rank, DType Dtype>
213int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
214{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800215 bool round = attribute->round();
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 int32_t num_bits = 0;
217 switch (Dtype)
218 {
219 case DType_INT8:
220 num_bits = 8;
221 break;
222 case DType_INT16:
223 num_bits = 16;
224 break;
225 case DType_INT32:
226 num_bits = 32;
227 break;
228 default:
229 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
230 }
231
Kevin Chengaee1fac2020-11-11 13:54:06 -0800232 this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
233 ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
234 (int32_t)b, num_bits);
235
236 InEigenType acc = a >> b;
237
238 if (round && b > 0 && (a >> (b - 1) & 1) != 0)
239 {
240 acc++;
241 }
242
243 return acc;
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 };
245
246 return 0;
247}
248
249template <int Rank, DType Dtype>
250int OpBitwiseAnd<Rank, Dtype>::register_fcn()
251{
252 switch (Dtype)
253 {
254 case DType_AINT8:
255 case DType_INT16:
256 case DType_INT32:
257 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
258 break;
259 default:
260 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
261 }
262
263 return 0;
264}
265
266template <int Rank, DType Dtype>
267int OpBitwiseOr<Rank, Dtype>::register_fcn()
268{
269 switch (Dtype)
270 {
271 case DType_AINT8:
272 case DType_INT16:
273 case DType_INT32:
274 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
275 break;
276 default:
277 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
278 }
279
280 return 0;
281}
282
283template <int Rank, DType Dtype>
284int OpBitwiseXor<Rank, Dtype>::register_fcn()
285{
286 switch (Dtype)
287 {
288 case DType_AINT8:
289 case DType_INT16:
290 case DType_INT32:
291 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
292 break;
293 default:
294 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
295 }
296
297 return 0;
298}
299
300template <int Rank, DType Dtype>
301int OpLogicalAnd<Rank, Dtype>::register_fcn()
302{
303 switch (Dtype)
304 {
305 case DType_BOOL:
306 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
307 break;
308 default:
309 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
310 }
311
312 return 0;
313}
314
315template <int Rank, DType Dtype>
316int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
317{
318 switch (Dtype)
319 {
320 case DType_INT8:
321 case DType_INT16:
322 case DType_INT32:
323 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
324 break;
325 default:
326 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
327 }
328
329 return 0;
330}
331
332template <int Rank, DType Dtype>
333int OpLogicalRightShift<Rank, Dtype>::register_fcn()
334{
335 int32_t num_bits = 0;
336 switch (Dtype)
337 {
338 case DType_INT8:
339 num_bits = 8;
340 break;
341 case DType_INT16:
342 num_bits = 16;
343 break;
344 case DType_INT32:
345 num_bits = 32;
346 break;
347 default:
348 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
349 }
350
351 this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
352 uint32_t mask = ONES_MASK(num_bits) >> b;
353 return (a >> b) & mask;
354 };
355
356 return 0;
357}
358
359template <int Rank, DType Dtype>
360int OpLogicalOr<Rank, Dtype>::register_fcn()
361{
362 switch (Dtype)
363 {
364 case DType_BOOL:
365 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
366 break;
367 default:
368 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
369 }
370
371 return 0;
372}
373
374template <int Rank, DType Dtype>
375int OpLogicalXor<Rank, Dtype>::register_fcn()
376{
377 switch (Dtype)
378 {
379 case DType_BOOL:
380 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
381 break;
382 default:
383 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
384 }
385
386 return 0;
387}
388
389template <int Rank, DType Dtype>
390int OpMaximum<Rank, Dtype>::register_fcn()
391{
392 switch (Dtype)
393 {
394 case DType_FLOAT:
395 case DType_INT32:
396 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
397 break;
398 default:
399 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
400 }
401
402 return 0;
403}
404
405template <int Rank, DType Dtype>
406int OpMinimum<Rank, Dtype>::register_fcn()
407{
408 switch (Dtype)
409 {
410 case DType_FLOAT:
411 case DType_INT32:
412 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
413 break;
414 default:
415 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
416 }
417
418 return 0;
419}
420
421template <int Rank, DType InDtype, DType OutDtype>
422int OpMul<Rank, InDtype, OutDtype>::register_fcn()
423{
Kevin Chengaee1fac2020-11-11 13:54:06 -0800424 int32_t shift = attribute->shift();
425 ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift,
426 EnumNamesDType()[InDtype]);
427
Eric Kunzee5e26762020-10-13 16:11:07 -0700428 switch (InDtype)
429 {
430 case DType_FLOAT:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800431 this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
432 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700433 case DType_INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -0800434 this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
435 int64_t result;
436 if (shift > 0)
437 {
438 int64_t round = 1L << (shift - 1);
439 result = a * b + round;
440 result = result >> shift;
441
442 ASSERT_MSG_NODE(result >= QMin && result <= QMax,
443 "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax);
444 }
445 else
446 {
447 result = a * b;
448 }
449
450 return static_cast<OutEigenType>(result);
451 };
Eric Kunzee5e26762020-10-13 16:11:07 -0700452 break;
453 case DType_INT8:
454 case DType_INT16:
455 this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
456 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
457
458 OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
459
460 return clamped_output;
461 };
462 break;
463 default:
464 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
465 }
466
467 return 0;
468}
469
470template <int Rank, DType Dtype>
471int OpPow<Rank, Dtype>::register_fcn()
472{
473 switch (Dtype)
474 {
475 case DType_FLOAT:
476 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
477 break;
478 default:
479 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
480 }
481
482 return 0;
483}
484
485template <int Rank, DType Dtype>
486int OpSub<Rank, Dtype>::register_fcn()
487{
488 switch (InDtype)
489 {
490 case DType_FLOAT:
491 case DType_INT32:
492 this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
493 break;
494 default:
495 FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
496 }
497
498 return 0;
499}
500
501template <int Rank>
502OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
503 : GraphNode(Op_TABLE, id_)
504{
505 setRequiredOperands(2, 1);
506 setRequiredRank(0, 6);
507}
508
509template <int Rank>
510OpTable<Rank>::~OpTable()
511{}
512
513template <int Rank>
514int OpTable<Rank>::checkTensorAttributes()
515{
516 if (validateRequiredOperands())
517 return 1;
518
519 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
520 {
521 return 1;
522 }
523
524 if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
525 {
526 FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
527 return 1;
528 }
529
530 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
531 table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
532 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
533
534 ASSERT_MEM(in && table && out);
535
536 return 0;
537}
538
539template <int Rank>
540int OpTable<Rank>::eval()
541{
542 this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
543 // 1. make sure input is int16 range
544 int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
545
546 // 2. calculate index and interpolation fraction
547 int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
548 index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
549 int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
550
551 // 3. interpolate, generate 16.7 (23-bit) output
552 int32_t base = this->table->getTensor()(index);
553 int32_t next = this->table->getTensor()(index + 1);
554 int32_t value = (base << 7) + (next - base) * frac;
555
556 return value;
557 });
558
559 return GraphNode::eval();
560}
561
562// template explicit instantiation
563DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
564DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
565
566DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
567DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
568DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
569
570DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
571DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
572DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
573
574DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
575DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
576DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
577
578DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
579DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
580DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
581
582DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
583
584DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
585DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
586DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
587
588DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
589DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
590DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
591
592DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
593
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
595
596DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
598
599DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
600DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
601
602DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
603DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
604DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
605DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
606
607DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
608
609DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
610DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
611
612DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
613
614DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
615DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);