// Copyright (c) 2020-2021, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

#include "data_layout.h"
#include "quant_util.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

template <int Rank, DType Dtype>
OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_CONCAT, id_)
{
    // CONCAT joins two input tensors into one output; ranks 1-6 are supported.
    setRequiredOperands(2, 1);
    setRequiredRank(1, 6);

    // Parse the Axis attribute (the dimension along which inputs are joined).
    INIT_ATTRIBUTE(Axis);
}
32
33template <int Rank, DType Dtype>
34OpConcat<Rank, Dtype>::~OpConcat()
35{
36 if (attribute)
37 delete attribute;
38}
39
40template <int Rank, DType Dtype>
41int OpConcat<Rank, Dtype>::checkTensorAttributes()
42{
43 if (validateRequiredOperands())
44 return 1;
45
46 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
47 {
48 return 1;
49 }
50
51 // output and input must be the same types and rank
52 // inputs[0] and inputs[1] should also match type and rank
53 if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]))
54 {
55 printNodeValidationError("Concat operator input ranks and types must match");
56 return 1;
57 }
58
59 lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
60 rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
61 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
62
63 if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size())
64 {
65 printNodeValidationError("Axis is beyond input tensor rank");
66 return 1;
67 }
68
69 return 0;
70}
71
template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::eval()
{

    // Eigen tensors are column-major while the reference model indexes
    // row-major, so the concat axis must be remapped into the
    // reversed-dimension space.
    int32_t reversed_axis = Rank - 1 - attribute->axis();

    // Shuffle permutation that reverses the dimension order.
    for (int32_t d = 0; d < Rank; d++)
    {
        reverser[d] = Rank - 1 - d;
    }

    // Reverse both operands into row-major order...
    TIn lhs_reversed = lhs->getTensor().shuffle(reverser);
    TIn rhs_reversed = rhs->getTensor().shuffle(reverser);

    // ...concatenate along the remapped axis, then reverse the result back.
    TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis);
    out->getTensor() = reversed_result.shuffle(reverser);
    // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis);

    return GraphNode::eval();
}
92
template <int Rank, DType Dtype>
OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_PAD, id_)
{
    // PAD takes the data tensor plus a paddings tensor (two inputs, one
    // output); any rank from 0 to 6 is accepted.
    setRequiredOperands(2, 1);
    setRequiredRank(0, 6);

    // Quantization info supplies the input zero point used as the pad value.
    INIT_QINFO(Pad);
}
102
103template <int Rank, DType Dtype>
104OpPad<Rank, Dtype>::~OpPad()
105{
106 if (qinfo)
107 delete qinfo;
108}
109
110template <int Rank, DType Dtype>
111int OpPad<Rank, Dtype>::checkTensorAttributes()
112{
113 if (validateRequiredOperands())
114 return 1;
115
116 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
117 {
118 return 1;
119 }
120
121 // output and input must be the same types
122 if (inputs[0]->matchRankType(*outputs[0]))
123 {
124 printNodeValidationError("Failure to match input and output type and rank");
125 return 1;
126 }
127
128 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
129 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
130 TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
131 dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
132
133 for (int i = 0; i < Rank; i++)
134 {
135 paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
136 }
137
138 return 0;
139}
140
141template <int Rank, DType Dtype>
142int OpPad<Rank, Dtype>::eval()
143{
144 InEigenType pad_value = 0;
145 if (this->qinfo)
146 {
147 pad_value = (InEigenType)this->qinfo->input_zp();
148 }
149
150 this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
151
152 return GraphNode::eval();
153}
154
template <int InRank, DType Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_RESHAPE, id_)
{
    // RESHAPE: one input, one output; any rank from 0 to 6 is accepted.
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Parse the Reshape attribute (the new shape, possibly with one -1).
    INIT_ATTRIBUTE(Reshape);
}
164
165template <int InRank, int OutRank, DType Dtype>
166OpReshape<InRank, OutRank, Dtype>::~OpReshape()
167{
168 if (attribute)
169 delete attribute;
170}
171
172template <int InRank, int OutRank, DType Dtype>
173int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
174{
175 uint32_t minusOneCount = 0;
176
177 if (validateRequiredOperands())
178 return 1;
179
180 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
181 {
182 return 1;
183 }
184
185 // output and input must be the same types
186 if (inputs[0]->matchType(*outputs[0]))
187 {
188 printNodeValidationError("OpReshape: Input and output types must match");
189 return 1;
190 }
191
192 for (uint32_t d = 0; d < OutRank; d++)
193 {
194 if (attribute->shape()[d] == -1)
195 {
196 minusOneCount++;
197 }
198 }
199
200 if (minusOneCount > 1)
201 {
202 printNodeValidationError("OpReshape: new shape has more than one -1 dimension");
203 return 1;
204 }
205
206 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
207 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
208
209 return 0;
210}
211
template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
    // Total element count, progressively divided down to resolve a -1
    // wildcard dimension.
    uint32_t remainingSize = in->getElementCount();

    // If there is a -1 dimension, find the remainder in one pass over the output shape
    for (int32_t d = 0; d < OutRank; d++)
    {
        if (attribute->shape()[d] != -1)
        {
            remainingSize = remainingSize / attribute->shape()[d];
        }
    }

    // Build the dimension-reversed target shape and the output shuffle map.
    for (int32_t d = 0; d < OutRank; d++)
    {
        array_shape[d] = attribute->shape()[OutRank - 1 - d];
        out_reverser[d] = OutRank - 1 - d;

        // Jam in the remainder here
        if (array_shape[d] == -1)
        {
            array_shape[d] = remainingSize;
        }
    }

    // Shuffle map that reverses the input's dimension order.
    for (int32_t d = 0; d < InRank; d++)
    {
        in_reverser[d] = InRank - 1 - d;
    }

    // Eigen Tensor is col-major, and we're referencing row-major result
    // need to reverse it to row-major before reshape, and perform another reverse afterward

    // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
    TIn in_reversed;
    if (InRank > 1)
    {
        in_reversed = in->getTensor().shuffle(in_reverser);
    }
    else
    {
        in_reversed = in->getTensor();
    }

    TOut in_reshaped = in_reversed.reshape(array_shape);

    // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
    if (OutRank > 1)
    {
        out->getTensor() = in_reshaped.shuffle(out_reverser);
    }
    else
    {
        out->getTensor() = in_reshaped;
    }

    return GraphNode::eval();
}
271
template <int Rank, DType Dtype>
OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_REVERSE, id_)
{
    // REVERSE: one input, one output; ranks 1-6 supported.
    setRequiredOperands(1, 1);
    setRequiredRank(1, 6);

    // Parse the Axis attribute (single dimension to reverse).
    INIT_ATTRIBUTE(Axis);
}
281
282template <int Rank, DType Dtype>
283OpReverse<Rank, Dtype>::~OpReverse()
284{
285 if (attribute)
286 delete attribute;
287}
288
289template <int Rank, DType Dtype>
290int OpReverse<Rank, Dtype>::checkTensorAttributes()
291{
292 if (validateRequiredOperands())
293 return 1;
294
295 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
296 {
297 return 1;
298 }
299
300 // output and input must be the same types
301 if (inputs[0]->matchRankTypeShape(*outputs[0]))
302 {
303 printNodeValidationError("Failure to match input and output rank/type/shape");
304 return 1;
305 }
306
307 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
308 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
309
310 ASSERT_MEM(in && out);
311
312 if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
313 {
314 printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
315 return 1;
316 }
317
318 // transform list of axis into true or false list
319 // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
320 for (int i = 0; i < Rank; i++)
321 {
322 reverse_array[i] = false;
323 }
324 reverse_array[attribute->axis()] = true;
325
326 return 0;
327}
328
template <int Rank, DType Dtype>
int OpReverse<Rank, Dtype>::eval()
{
    // Flip the tensor along every axis flagged in reverse_array (prepared in
    // checkTensorAttributes()); Eigen's reverse() performs the element copy.
    out->getTensor() = in->getTensor().reverse(reverse_array);

    return GraphNode::eval();
}
336
template <int Rank, DType Dtype>
OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_SLICE, id_)
{
    // SLICE: one input, one output; ranks 0-6 supported.
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Parse the Slice attribute (per-dimension begin offsets and sizes).
    INIT_ATTRIBUTE(Slice);
}
346
347template <int Rank, DType Dtype>
348OpSlice<Rank, Dtype>::~OpSlice()
349{
350 if (attribute)
351 delete attribute;
352}
353
354template <int Rank, DType Dtype>
355int OpSlice<Rank, Dtype>::checkTensorAttributes()
356{
357 if (validateRequiredOperands())
358 return 1;
359
360 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
361 {
362 return 1;
363 }
364
365 // output and input must be the same types
366 if (inputs[0]->matchType(*outputs[0]))
367 {
368 printNodeValidationError("Failure to match input and output type");
369 return 1;
370 }
371
372 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
373 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
374
375 for (size_t i = 0; i < attribute->begin().size(); i++)
376 {
377 begin_array[i] = attribute->begin()[i];
378 }
379
380 for (size_t i = 0; i < attribute->size().size(); i++)
381 {
382 if (attribute->size()[i] != 0)
383 {
384 size_array[i] = attribute->size()[i];
385 }
386 else
387 {
388 // Tensorflow assigns a zero size to dimensions that are kept
389 // Eigen expects size to be the full size of the dimension
390 size_array[i] = in->getTensor().dimension(0);
391 }
392 }
393
394 return 0;
395}
396
template <int Rank, DType Dtype>
int OpSlice<Rank, Dtype>::eval()
{
    // Extract the window described by begin_array/size_array (prepared in
    // checkTensorAttributes()) using Eigen's slice operation.
    out->getTensor() = in->getTensor().slice(begin_array, size_array);

    return GraphNode::eval();
}
404
template <int Rank, DType Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_TILE, id_)
{
    // TILE: one input, one output; ranks 0-6 accepted here, though only the
    // rank 1-4 OpTile specializations implement eval().
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Parse the Tile attribute (per-dimension replication multiples).
    INIT_ATTRIBUTE(Tile);
}
414
415template <int Rank, DType Dtype>
416OpTileBase<Rank, Dtype>::~OpTileBase()
417{
418 if (attribute)
419 delete attribute;
420}
421
422template <int Rank, DType Dtype>
423int OpTileBase<Rank, Dtype>::checkTensorAttributes()
424{
425 if (validateRequiredOperands())
426 return 1;
427
428 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
429 {
430 return 1;
431 }
432
433 // output and input must be the same ranks and types
434 if (inputs[0]->matchRankType(*outputs[0]))
435 {
436 printNodeValidationError("Failure to match input and output rank or type");
437 return 1;
438 }
439
440 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
441 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
442
443 if (attribute->multiples().size() != Rank)
444 {
445 printNodeValidationError("1D list 'multiples' must have size equal to input rank");
446 return 1;
447 }
448
449 for (int32_t d = 0; d < Rank; d++)
450 {
451 if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
452 {
453 printNodeValidationError("unexpected output shape");
454 return 1;
455 }
456 }
457
458 return 0;
459}
460
template <int Rank, DType Dtype>
int OpTile<Rank, Dtype>::eval()
{
    // primary template shouldn't be called; only the rank 1-4 partial
    // specializations below implement TILE.
    FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
}
467
468template <DType Dtype>
469int OpTile<1, Dtype>::eval()
470{
471 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
472 {
473 int32_t id0 = od0 % this->in->getShape()[0];
474 this->out->getTensor()(od0) = this->in->getTensor()(id0);
475 }
476
477 return GraphNode::eval();
478}
479
480template <DType Dtype>
481int OpTile<2, Dtype>::eval()
482{
483 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
484 {
485 int32_t id0 = od0 % this->in->getShape()[0];
486 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
487 {
488 int32_t id1 = od1 % this->in->getShape()[1];
489 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
490 }
491 }
492
493 return GraphNode::eval();
494}
495
496template <DType Dtype>
497int OpTile<3, Dtype>::eval()
498{
499 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
500 {
501 int32_t id0 = od0 % this->in->getShape()[0];
502 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
503 {
504 int32_t id1 = od1 % this->in->getShape()[1];
505 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
506 {
507 int32_t id2 = od2 % this->in->getShape()[2];
508 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
509 }
510 }
511 }
512
513 return GraphNode::eval();
514}
515
516template <DType Dtype>
517int OpTile<4, Dtype>::eval()
518{
519 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
520 {
521 int32_t id0 = od0 % this->in->getShape()[0];
522 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
523 {
524 int32_t id1 = od1 % this->in->getShape()[1];
525 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
526 {
527 int32_t id2 = od2 % this->in->getShape()[2];
528 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
529 {
530 int32_t id3 = od3 % this->in->getShape()[3];
531 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
532 }
533 }
534 }
535 }
536
537 return GraphNode::eval();
538}
539
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
    : GraphNode(Op_TRANSPOSE, id_)
{
    // TRANSPOSE takes the data tensor plus a 1-D permutation tensor (two
    // inputs, one output); attribute_ and qinfo_ are unused for this op.
    setRequiredOperands(2, 1);
    setRequiredRank(0, 6);
}
547
// Nothing allocated in the constructor, so nothing to release here.
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::~OpTranspose()
{}
551
552template <int Rank, DType Dtype>
553int OpTranspose<Rank, Dtype>::checkTensorAttributes()
554{
555 if (validateRequiredOperands())
556 return 1;
557
558 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
559 {
560 return 1;
561 }
562
563 // output and input must be the same types
564 if (inputs[0]->matchRankType(*outputs[0]))
565 {
566 printNodeValidationError("Failure to match input and output rank and type");
567 return 1;
568 }
569
570 if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
571 {
572 printNodeValidationError("Failure to match input and output total element count");
573 return 1;
574 }
575
576 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
577 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
578 perm_tensor = dynamic_cast<TosaReference::TensorTemplate<ETensor1<int32_t>>*>(inputs[1]);
579
580 return 0;
581}
582
583template <int Rank, DType Dtype>
584int OpTranspose<Rank, Dtype>::eval()
585{
586 for (int32_t d = 0; d < Rank; d++)
587 {
588 perm_array[d] = this->perm_tensor->getTensor().data()[d];
589 }
590
591 out->getTensor() = in->getTensor().shuffle(perm_array);
592
593 return GraphNode::eval();
594}
595
// template explicit instantiation
// One instantiation per supported element type (expanded over the supported
// ranks by the DEF_INSTANTIATE_* macros) so the linker can resolve the
// graph builder's uses of these class templates.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);

DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);