blob: 05a11e01f319668af842e168970f94c4287ec952 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "data_layout.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
// OpConcat: concatenates a variable number of input tensors along one axis.
template <int Rank, DType Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                TosaQuantInfoBase* qinfo_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_CONCAT, id_)
{
    // -1: CONCAT accepts any number of inputs (emptiness is rejected in
    // checkTensorAttributes()); exactly one output.
    setRequiredOperands(-1, 1);
    setRequiredRank(1, 6);

    // Allocates `attribute` (the Axis attribute); released in ~OpConcat().
    INIT_ATTRIBUTE(Axis);
}
35
36template <int Rank, DType Dtype>
37OpConcat<Rank, Dtype>::~OpConcat()
38{
39 if (attribute)
40 delete attribute;
41}
42
// Validates operand counts, rank/type agreement, the axis attribute, and
// that the output shape equals the inputs' shapes with the axis dimension
// summed. Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (inputs.empty())
    {
        printNodeValidationError("Concat operator must have at least one input tensor");
        return 1;
    }

    int32_t num_inputs = inputs.size();

    // output and input must be the same types and rank
    for (int32_t i = 0; i < num_inputs; i++)
    {
        if (inputs[i]->matchRankType(*outputs[0]))
        {
            printNodeValidationError("OpConcat: input ranks and types must match");
            return 1;
        }
        // Cache the typed tensor pointers for eval().
        // NOTE(review): the dynamic_cast result is not null-checked here,
        // unlike ASSERT_MEM usage elsewhere in this file — confirm intended.
        ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
    }

    // Axis must index a valid dimension of the (rank-matched) tensors.
    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
    {
        printNodeValidationError("OpConcat: axis is beyond output tensor rank");
        return 1;
    }

    // Sum the axis dimension across inputs; all other dimensions must match
    // the output exactly.
    int32_t output_dim_on_axis = 0;
    for (int32_t j = 0; j < num_inputs; j++)
    {
        for (int32_t i = 0; i < Rank; i++)
        {
            int32_t input_dim = inputs[j]->getShape()[i];
            if (i == attribute->axis())
            {
                output_dim_on_axis += input_dim;
            }
            else if (input_dim != outputs[0]->getShape()[i])
            {
                printNodeValidationError("OpConcat: input dimension not matching output dimension");
                return 1;
            }
        }
    }

    ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
             "OpConcat: sum of input dimension on axis not equal to output dimension on axis");

    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}
99
// Concatenates the cached inputs along the attribute axis.
// Eigen tensors are column-major while the reference model is row-major, so
// every tensor is dimension-reversed (shuffled) first, concatenated along the
// reversed axis, then reversed back.
template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::eval()
{

    // Axis position after the dimension reversal below.
    int32_t reversed_axis = Rank - 1 - attribute->axis();

    for (int32_t d = 0; d < Rank; d++)
    {
        reverser[d] = Rank - 1 - d;
    }

    TIn result = ins[0]->getTensor().shuffle(reverser);

    for (size_t i = 1; i < ins.size(); i++)
    {
        TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
        // Materialize into `temp` before assigning back to avoid aliasing
        // `result` on both sides of the Eigen expression.
        TIn temp = result.concatenate(in_reversed, reversed_axis);
        result = temp;
    }
    // Undo the dimension reversal for the row-major output.
    out->getTensor() = result.shuffle(reverser);

    return GraphNode::eval();
}
123
// OpPad: pads a tensor with a constant value per the Pad attribute.
template <int Rank, DType Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
                          TosaAttributeBase* attribute_,
                          TosaQuantInfoBase* qinfo_,
                          uint64_t id_)
    : GraphNode(sgt_, Op_PAD, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Allocates `qinfo` (input zero-point) and `attribute` (padding + pad
    // constants). NOTE(review): only qinfo is freed in ~OpPad() as written.
    INIT_QINFO(Pad);
    INIT_ATTRIBUTE(Pad);
}
137
138template <int Rank, DType Dtype>
139OpPad<Rank, Dtype>::~OpPad()
140{
141 if (qinfo)
142 delete qinfo;
143}
144
// Validates operands, decodes the flattened padding attribute into
// paddings_array for eval(), and checks the zero-point constraint.
// Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output type and rank");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    ASSERT_MEM(in && out);

    // padding in spec is 2D array in shape of [Rank, 2]
    // Reference model implement this as 1D array of [Rank * 2], with ordering:
    // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
    ERROR_IF(attribute->padding().size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");

    for (int i = 0; i < Rank; i++)
    {
        int32_t pad_front = attribute->padding()[2 * i];
        int32_t pad_back = attribute->padding()[2 * i + 1];
        ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
        // Eigen's pad() takes (front, back) pairs per dimension.
        paddings_array[i] = std::make_pair(pad_front, pad_back);
    }

    // Zero-point is only meaningful for INT8; any other dtype must carry 0.
    if (this->qinfo && Dtype != DType_INT8)
    {
        ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0");
    }

    return 0;
}
187
// Pads the input with a constant taken from the attribute (integer or float
// variant by dtype), offset by the INT8 input zero-point when present.
// Dtypes not listed in the switch leave pad_value at 0.
template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::eval()
{
    InEigenType pad_value = 0;

    switch (Dtype)
    {
        case DType_BOOL:
        case DType_INT8:
        case DType_INT16:
        case DType_INT32:
            pad_value = (InEigenType)attribute->pad_const_int();
            break;
        case DType_FLOAT:
            pad_value = (InEigenType)attribute->pad_const_fp();
            break;
    }

    // INT8 pad constant is expressed relative to the input zero-point.
    if (this->qinfo && Dtype == DType_INT8)
    {
        pad_value += (InEigenType)this->qinfo->input_zp();
    }

    this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);

    return GraphNode::eval();
}
215
// OpReshape: reinterprets the input's elements under a new shape.
// Templated on both input and output rank since they may differ.
template <int InRank, int OutRank, DType Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
                                             TosaAttributeBase* attribute_,
                                             TosaQuantInfoBase* qinfo_,
                                             uint64_t id_)
    : GraphNode(sgt_, Op_RESHAPE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Allocates `attribute` (new_shape); released in ~OpReshape().
    INIT_ATTRIBUTE(Reshape);
}
228
229template <int InRank, int OutRank, DType Dtype>
230OpReshape<InRank, OutRank, Dtype>::~OpReshape()
231{
232 if (attribute)
233 delete attribute;
234}
235
// Validates that input/output carry the same type and element count, that
// new_shape matches the output shape, and that at most one dimension is the
// -1 "infer me" placeholder. Returns 0 on success, 1 on failure.
template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
    uint32_t minusOneCount = 0;

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpReshape: Input and output types must match");
        return 1;
    }

    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
             "Input tensor size does not match output tensor size");

    for (uint32_t d = 0; d < OutRank; d++)
    {
        if (attribute->shape()[d] == -1)
        {
            // Wildcard dimension; resolved in eval() from the element count.
            minusOneCount++;
        }
        else
        {
            ERROR_IF(attribute->shape()[d] != outputs[0]->getShape()[d],
                     "OpReshape: new_shape doesn't match output shape");
        }
    }

    if (minusOneCount > 1)
    {
        printNodeValidationError("OpReshape: new shape has more than one -1 dimension");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}
283
// Performs the reshape. Resolves a single -1 dimension by dividing the
// total element count by the product of the explicit dimensions, then does a
// col-major/row-major dance: reverse dims, reshape, reverse back. Rank-0
// tensors can't be shuffled/reshaped in Eigen and are special-cased.
template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
    uint32_t remainingSize = in->getElementCount();

    // If there is a -1 dimension, find the remainder in one pass over the output shape
    for (int32_t d = 0; d < OutRank; d++)
    {
        if (attribute->shape()[d] != -1)
        {
            remainingSize = remainingSize / attribute->shape()[d];
        }
    }

    for (int32_t d = 0; d < OutRank; d++)
    {
        // Shape is stored reversed because the reshape happens in the
        // dimension-reversed (col-major-compatible) domain.
        array_shape[d] = attribute->shape()[OutRank - 1 - d];
        out_reverser[d] = OutRank - 1 - d;

        // Jam in the remainder here
        if (array_shape[d] == -1)
        {
            array_shape[d] = remainingSize;
        }
    }

    for (int32_t d = 0; d < InRank; d++)
    {
        in_reverser[d] = InRank - 1 - d;
    }

    // Eigen Tensor is col-major, and we're referencing row-major result
    // need to reverse it to row-major before reshape, and perform another reverse afterward

    // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
    TIn in_reversed;
    if (InRank > 1)
    {
        in_reversed = in->getTensor().shuffle(in_reverser);
    }
    else
    {
        in_reversed = in->getTensor();
    }

    TOut in_reshaped = in_reversed.reshape(array_shape);

    // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
    if (OutRank > 1)
    {
        out->getTensor() = in_reshaped.shuffle(out_reverser);
    }
    else
    {
        out->getTensor() = in_reshaped;
    }

    return GraphNode::eval();
}
343
// OpReverse: reverses the input along a single axis.
template <int Rank, DType Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
                                  TosaAttributeBase* attribute_,
                                  TosaQuantInfoBase* qinfo_,
                                  uint64_t id_)
    : GraphNode(sgt_, Op_REVERSE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 6);

    // Allocates `attribute` (the Axis attribute); released in ~OpReverse().
    INIT_ATTRIBUTE(Axis);
}
356
357template <int Rank, DType Dtype>
358OpReverse<Rank, Dtype>::~OpReverse()
359{
360 if (attribute)
361 delete attribute;
362}
363
// Validates operands and the axis, then precomputes the boolean per-axis
// reverse mask consumed by eval(). Returns 0 on success, 1 on failure.
template <int Rank, DType Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank/type/shape");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
    {
        printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
        return 1;
    }

    // transform list of axis into true or false list
    // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
    for (int i = 0; i < Rank; i++)
    {
        reverse_array[i] = false;
    }
    reverse_array[attribute->axis()] = true;

    return 0;
}
403
404template <int Rank, DType Dtype>
405int OpReverse<Rank, Dtype>::eval()
406{
407 out->getTensor() = in->getTensor().reverse(reverse_array);
408
409 return GraphNode::eval();
410}
411
// OpSlice: extracts a sub-tensor described by per-axis begin/size.
template <int Rank, DType Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
                              TosaAttributeBase* attribute_,
                              TosaQuantInfoBase* qinfo_,
                              uint64_t id_)
    : GraphNode(sgt_, Op_SLICE, id_)
{
    setRequiredOperands(1, 1);
    // NOTE(review): rank limited to 1..4 here while the explicit
    // instantiations below use the RANK0_6 macro — confirm intended.
    setRequiredRank(1, 4);

    // Allocates `attribute` (begin/size lists); released in ~OpSlice().
    INIT_ATTRIBUTE(Slice);
}
424
425template <int Rank, DType Dtype>
426OpSlice<Rank, Dtype>::~OpSlice()
427{
428 if (attribute)
429 delete attribute;
430}
431
// Validates operands and the begin/size attributes (bounds, positivity,
// agreement with the output shape), caching them into begin_array/size_array
// for eval(). Returns 0 on success, 1 on failure.
template <int Rank, DType Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output type");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF((int32_t)attribute->begin().size() != in->getRank(),
             "OpSlice: begin array length needs to be rank(input)");
    ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");

    for (int32_t i = 0; i < in->getRank(); i++)
    {
        int32_t b = attribute->begin()[i];
        int32_t s = attribute->size()[i];
        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
        ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
        ERROR_IF(s <= 0, "OpSlice: output must be positive");
        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
        begin_array[i] = b;
        size_array[i] = s;
    }

    return 0;
}
471
472template <int Rank, DType Dtype>
473int OpSlice<Rank, Dtype>::eval()
474{
475 out->getTensor() = in->getTensor().slice(begin_array, size_array);
476
477 return GraphNode::eval();
478}
479
// OpTileBase: common setup for the rank-specialized OpTile implementations,
// which replicate the input per the `multiples` attribute.
template <int Rank, DType Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
                                    TosaAttributeBase* attribute_,
                                    TosaQuantInfoBase* qinfo_,
                                    uint64_t id_)
    : GraphNode(sgt_, Op_TILE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Allocates `attribute` (multiples); released in ~OpTileBase().
    INIT_ATTRIBUTE(Tile);
}
492
493template <int Rank, DType Dtype>
494OpTileBase<Rank, Dtype>::~OpTileBase()
495{
496 if (attribute)
497 delete attribute;
498}
499
500template <int Rank, DType Dtype>
501int OpTileBase<Rank, Dtype>::checkTensorAttributes()
502{
503 if (validateRequiredOperands())
504 return 1;
505
506 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
507 {
508 return 1;
509 }
510
511 // output and input must be the same ranks and types
512 if (inputs[0]->matchRankType(*outputs[0]))
513 {
514 printNodeValidationError("Failure to match input and output rank or type");
515 return 1;
516 }
517
518 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
519 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
520
521 if (attribute->multiples().size() != Rank)
522 {
523 printNodeValidationError("1D list 'multiples' must have size equal to input rank");
524 return 1;
525 }
526
527 for (int32_t d = 0; d < Rank; d++)
528 {
529 if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
530 {
531 printNodeValidationError("unexpected output shape");
532 return 1;
533 }
534 }
535
536 return 0;
537}
538
// Primary template: only the rank-1..4 partial specializations below are
// implemented, so reaching this overload is a fatal configuration error.
template <int Rank, DType Dtype>
int OpTile<Rank, Dtype>::eval()
{
    // primary template shouldn't be called
    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
}
545
546template <DType Dtype>
547int OpTile<1, Dtype>::eval()
548{
549 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
550 {
551 int32_t id0 = od0 % this->in->getShape()[0];
552 this->out->getTensor()(od0) = this->in->getTensor()(id0);
553 }
554
555 return GraphNode::eval();
556}
557
558template <DType Dtype>
559int OpTile<2, Dtype>::eval()
560{
561 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
562 {
563 int32_t id0 = od0 % this->in->getShape()[0];
564 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
565 {
566 int32_t id1 = od1 % this->in->getShape()[1];
567 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
568 }
569 }
570
571 return GraphNode::eval();
572}
573
574template <DType Dtype>
575int OpTile<3, Dtype>::eval()
576{
577 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
578 {
579 int32_t id0 = od0 % this->in->getShape()[0];
580 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
581 {
582 int32_t id1 = od1 % this->in->getShape()[1];
583 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
584 {
585 int32_t id2 = od2 % this->in->getShape()[2];
586 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
587 }
588 }
589 }
590
591 return GraphNode::eval();
592}
593
594template <DType Dtype>
595int OpTile<4, Dtype>::eval()
596{
597 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
598 {
599 int32_t id0 = od0 % this->in->getShape()[0];
600 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
601 {
602 int32_t id1 = od1 % this->in->getShape()[1];
603 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
604 {
605 int32_t id2 = od2 % this->in->getShape()[2];
606 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
607 {
608 int32_t id3 = od3 % this->in->getShape()[3];
609 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
610 }
611 }
612 }
613 }
614
615 return GraphNode::eval();
616}
617
// OpTranspose: permutes the input's dimensions per the Transpose attribute.
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
                                      TosaAttributeBase* attribute_,
                                      TosaQuantInfoBase* qinfo_,
                                      uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Allocates `attribute` (the permutation).
    // NOTE(review): the destructor as written never frees it.
    INIT_ATTRIBUTE(Transpose);
}
630
631template <int Rank, DType Dtype>
632OpTranspose<Rank, Dtype>::~OpTranspose()
633{}
634
// Validates operands: same rank/type and same total element count between
// input and output. The permutation itself is range-checked later, in
// eval(). Returns 0 on success, 1 on failure.
template <int Rank, DType Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank and type");
        return 1;
    }

    if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
    {
        printNodeValidationError("Failure to match input and output total element count");
        return 1;
    }

    // NOTE(review): attribute->perm() length is not validated against Rank
    // here; eval() indexes perm()[0..Rank-1] directly — confirm upstream
    // serialization guarantees the length.
    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    return 0;
}
666
// Copies the permutation into perm_array, range-checking each entry, then
// lets Eigen's shuffle() perform the transpose.
// NOTE(review): duplicate permutation entries are not rejected here, only
// out-of-range ones — confirm whether duplicates are caught upstream.
template <int Rank, DType Dtype>
int OpTranspose<Rank, Dtype>::eval()
{
    for (int32_t d = 0; d < Rank; d++)
    {
        perm_array[d] = attribute->perm()[d];
        ERROR_IF(perm_array[d] < 0 or perm_array[d] >= Rank, "OpTranspose: index out of boundary");
    }

    out->getTensor() = in->getTensor().shuffle(perm_array);

    return GraphNode::eval();
}
680
// template explicit instantiation
// Each DEF_INSTANTIATE_* macro expands to explicit instantiations of the
// named operator over the supported rank range for one element type.
// (RANK1_6 = ranks 1..6, RANK0_6 = ranks 0..6; OpReshape gets a dedicated
// macro covering input-rank x output-rank combinations.)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);

DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);

// NOTE(review): OpSlice is instantiated for ranks 0..6 here while its
// constructor calls setRequiredRank(1, 4) — confirm the mismatch is benign.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);