// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ewise_binary.h"
#include "arith_util.h"
#include "quant_util.h"
#include "template_types.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

template <int Rank, DType InDtype, DType OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(op_, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(0, 6);

    a_rank = b_rank = max_input_rank = -1;
    a = b = nullptr;
    a_rank0 = b_rank0 = nullptr;
    result = nullptr;

    fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
}

template <int Rank, DType InDtype, DType OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
{}

template <int Rank, DType InDtype, DType OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    a_rank = inputs[0]->getRank();
    b_rank = inputs[1]->getRank();
    if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
    {
        printNodeValidationError("Binary operator input ranks must match");
        return 1;
    }

    max_input_rank = a_rank >= b_rank ? a_rank : b_rank;

    // A & B must be the same types
    if (inputs[0]->matchType(*inputs[1]))
    {
        printNodeValidationError("Binary operator input types must match");
        return 1;
    }

    // Result's geometry must match, but the type may be wider
    if (outputs[0]->getRank() != max_input_rank)
    {
        printNodeValidationError("Binary operator input and output geometry must match");
        return 1;
    }

    if (a_rank == max_input_rank)
    {
        a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    }
    else
    {
        a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
    }

    if (b_rank == max_input_rank)
    {
        b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    }
    else
    {
        b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
    }

    result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Either a or b can be rank 0.
    // a_rank0 and b_rank0 can't both be valid at the same time.
    // If a and b are both rank 0, they are evaluated as 'a' and 'b' instead of 'a_rank0' and 'b_rank0'.
    ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);

    return 0;
}

template <int Rank, DType InDtype, DType OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
{
    auto output_shape = result->getTensor().dimensions();

    std::vector<int> a_shape, b_shape;

    if (a_rank == max_input_rank)
    {
        a_shape = a->getShape();
    }
    else
    {
        a_shape.assign(max_input_rank, 1);
    }

    if (b_rank == max_input_rank)
    {
        b_shape = b->getShape();
    }
    else
    {
        b_shape.assign(max_input_rank, 1);
    }

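    // bcast_a/bcast_b hold the Eigen broadcast factors for each dimension: a size-1 input
    // dimension that must grow to the output size is replicated output_shape[i] times,
    // while a dimension that already matches uses a factor of 1 (no replication).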
    for (int i = 0; i < max_input_rank; i++)
    {
        if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
        {
            bcast_a[i] = output_shape[i];
        }
        else
        {
            bcast_a[i] = 1;
        }
        if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
        {
            bcast_b[i] = output_shape[i];
        }
        else
        {
            bcast_b[i] = 1;
        }
    }

    return 0;
}

template <int Rank, DType InDtype, DType OutDtype>
int BinaryNode<Rank, InDtype, OutDtype>::eval()
{
    this->broadcast();

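    // A rank-0 operand is first reshaped to an all-ones shape of rank 'Rank' so that the
    // broadcast factors computed above can be applied to it like any full-rank tensor.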
    Eigen::array<int, Rank> reshaper;
    reshaper.fill(1);
    TIn ia, ib;

    if (this->a_rank == this->max_input_rank)
    {
        ia = this->a->getTensor().broadcast(this->bcast_a);
    }
    else
    {
        ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
    }

    if (this->b_rank == this->max_input_rank)
    {
        ib = this->b->getTensor().broadcast(this->bcast_b);
    }
    else
    {
        ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
    }

    this->result->getTensor() = ia.binaryExpr(ib, this->fcn);

    return GraphNode::eval();
}

// Still need to partially specialize this, or Eigen will throw a static assertion
template <DType InDtype, DType OutDtype>
int BinaryNode<0, InDtype, OutDtype>::eval()
{
    this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
int OpAdd<Rank, Dtype>::register_fcn()
{
    switch (InDtype)
    {
        case DType_FLOAT:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
{
    int32_t num_bits = 0;
    switch (Dtype)
    {
        case DType_INT8:
            num_bits = 8;
            break;
        case DType_INT16:
            num_bits = 16;
            break;
        case DType_INT32:
            num_bits = 32;
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

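    // Emulate an arithmetic right shift on a num_bits-wide value: when the sign bit of 'a'
    // is set, the vacated high-order bits are filled with ones (ones_mask); otherwise those
    // bits are cleared so the result stays non-negative.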
    this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
        uint32_t sign      = a & (1 << (num_bits - 1));
        uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
        if (sign)
            return ones_mask | (a >> b);
        else
            return (~ones_mask) & (a >> b);
    };

    return 0;
}

template <int Rank, DType Dtype>
int OpBitwiseAnd<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_AINT8:
        case DType_INT16:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpBitwiseOr<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_AINT8:
        case DType_INT16:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpBitwiseXor<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_AINT8:
        case DType_INT16:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpLogicalAnd<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_BOOL:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_INT8:
        case DType_INT16:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpLogicalRightShift<Rank, Dtype>::register_fcn()
{
    int32_t num_bits = 0;
    switch (Dtype)
    {
        case DType_INT8:
            num_bits = 8;
            break;
        case DType_INT16:
            num_bits = 16;
            break;
        case DType_INT32:
            num_bits = 32;
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

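    // Logical (zero-filling) right shift: mask off the high-order bits so that any sign
    // extension performed by '>>' on a signed type does not leak into the result.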
    this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
        uint32_t mask = ONES_MASK(num_bits) >> b;
        return (a >> b) & mask;
    };

    return 0;
}

template <int Rank, DType Dtype>
int OpLogicalOr<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_BOOL:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpLogicalXor<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_BOOL:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpMaximum<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_FLOAT:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpMinimum<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_FLOAT:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType InDtype, DType OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
    switch (InDtype)
    {
        case DType_FLOAT:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
            break;
        case DType_INT8:
        case DType_INT16:
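            // Widening multiply for the narrow integer types: compute the product in the
            // wider OutEigenType, then clamp it to [QMin, QMax] so it saturates rather than wrapping.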
            this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
                OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;

                OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));

                return clamped_output;
            };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpPow<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_FLOAT:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpSub<Rank, Dtype>::register_fcn()
{
    switch (InDtype)
    {
        case DType_FLOAT:
        case DType_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
    }

    return 0;
}

template <int Rank>
OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_TABLE, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(0, 6);
}

template <int Rank>
OpTable<Rank>::~OpTable()
{}

template <int Rank>
int OpTable<Rank>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
    {
        FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && table && out);

    return 0;
}

template <int Rank>
int OpTable<Rank>::eval()
{
    this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
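        // Table lookup with linear interpolation: the upper bits of the clamped input select
        // a base entry, the low 7 bits give the fraction toward the next entry, and reading
        // index + 1 is why the table carries 513 entries rather than 512.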
        // 1. make sure input is int16 range
        int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);

        // 2. calculate index and interpolation fraction
        int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
        index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1);    // 9-bit index
        int32_t frac = (input_truncated) & 0x7F;    // 7-bit fraction

        // 3. interpolate, generate 16.7 (23-bit) output
        int32_t base = this->table->getTensor()(index);
        int32_t next = this->table->getTensor()(index + 1);
        int32_t value = (base << 7) + (next - base) * frac;

        return value;
    });

    return GraphNode::eval();
}

// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);

DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);