blob: df7084de6b2cdc7aab94d46d09f1c9ebf81ff42f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "data_layout.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
23template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070024OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
25 TosaAttributeBase* attribute_,
26 TosaQuantInfoBase* qinfo_,
27 uint64_t id_)
28 : GraphNode(sgt_, Op_CONCAT, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070029{
Kevin Chengad15dfa2021-03-04 15:15:03 -080030 setRequiredOperands(-1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 setRequiredRank(1, 6);
32
33 INIT_ATTRIBUTE(Axis);
34}
35
36template <int Rank, DType Dtype>
37OpConcat<Rank, Dtype>::~OpConcat()
38{
39 if (attribute)
40 delete attribute;
41}
42
// Validates CONCAT operands and attributes.
// Checks: at least one input; every input matches the output's rank and type;
// the axis attribute is within [0, Rank); non-axis dimensions of every input
// match the output; and the output dimension along the axis equals the sum of
// the input dimensions along that axis. Also caches typed tensor pointers in
// `ins` / `out` for use by eval().
// Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (inputs.empty())
    {
        printNodeValidationError("Concat operator must have at least one input tensor");
        return 1;
    }

    int32_t num_inputs = inputs.size();

    // output and input must be the same types and rank
    for (int32_t i = 0; i < num_inputs; i++)
    {
        // matchRankType() returns nonzero on mismatch.
        if (inputs[i]->matchRankType(*outputs[0]))
        {
            printNodeValidationError("OpConcat: input ranks and types must match");
            return 1;
        }
        ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
    }

    // Axis must index a valid dimension of the (rank-matched) tensors.
    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
    {
        printNodeValidationError("OpConcat: axis is beyond output tensor rank")
;
        return 1;
    }

    // Accumulate the axis dimension across inputs while checking that every
    // non-axis dimension agrees with the output shape.
    int32_t output_dim_on_axis = 0;
    for (int32_t j = 0; j < num_inputs; j++)
    {
        for (int32_t i = 0; i < Rank; i++)
        {
            int32_t input_dim = inputs[j]->getShape()[i];
            if (i == attribute->axis())
            {
                output_dim_on_axis += input_dim;
            }
            else if (input_dim != outputs[0]->getShape()[i])
            {
                printNodeValidationError("OpConcat: input dimension not matching output dimension");
                return 1;
            }
        }
    }

    ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
             "OpConcat: sum of input dimension on axis not equal to output dimension on axis");

    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}
99
// Concatenates all cached inputs along the attribute axis.
// Eigen tensors are column-major while the reference model uses row-major
// shapes, so each input is dimension-reversed (shuffle) before concatenating
// along the correspondingly reversed axis, then the result is reversed back.
template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::eval()
{

    // Axis position after the dimension reversal below.
    int32_t reversed_axis = Rank - 1 - attribute->axis();

    for (int32_t d = 0; d < Rank; d++)
    {
        reverser[d] = Rank - 1 - d;
    }

    TIn result = ins[0]->getTensor().shuffle(reverser);

    for (size_t i = 1; i < ins.size(); i++)
    {
        TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
        // Materialize into a temporary before assigning back to `result`:
        // assigning a concatenate expression that reads `result` directly to
        // `result` itself would alias source and destination.
        TIn temp = result.concatenate(in_reversed, reversed_axis);
        result = temp;
    }
    // Undo the dimension reversal for the row-major output.
    out->getTensor() = result.shuffle(reverser);

    return GraphNode::eval();
}
123
124template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700125OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
126 TosaAttributeBase* attribute_,
127 TosaQuantInfoBase* qinfo_,
128 uint64_t id_)
129 : GraphNode(sgt_, Op_PAD, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700130{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000131 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700132 setRequiredRank(0, 6);
133
134 INIT_QINFO(Pad);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000135 INIT_ATTRIBUTE(Pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700136}
137
138template <int Rank, DType Dtype>
139OpPad<Rank, Dtype>::~OpPad()
140{
141 if (qinfo)
142 delete qinfo;
143}
144
// Validates PAD operands and attributes.
// Checks rank limits, matching input/output rank+type, that the flattened
// padding attribute has exactly Rank*2 entries, and that every padding value
// is non-negative. Builds `paddings_array` (front/back pair per dimension)
// for eval(), and enforces a zero input zero-point for non-INT8 types.
// Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output type and rank");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    ASSERT_MEM(in && out);

    // padding in spec is 2D array in shape of [Rank, 2]
    // Reference model implement this as 1D array of [Rank * 2], with ordering:
    // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
    ERROR_IF(attribute->padding().size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");

    for (int i = 0; i < Rank; i++)
    {
        int32_t pad_front = attribute->padding()[2 * i];
        int32_t pad_back = attribute->padding()[2 * i + 1];
        ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
        // Eigen's pad() takes (front, back) pairs per dimension.
        paddings_array[i] = std::make_pair(pad_front, pad_back);
    }

    // Only INT8 may carry a non-zero input zero point.
    if (this->qinfo && Dtype != DType_INT8)
    {
        ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0");
    }

    return 0;
}
187
// Pads the input tensor with a constant value.
// The pad constant comes from the attribute (integer or float variant
// depending on Dtype); for INT8 with quantization info the input zero point
// is added so padding is performed in the quantized domain.
template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::eval()
{
    InEigenType pad_value = 0;

    // Select the pad constant field matching the tensor's data type.
    switch (Dtype)
    {
        case DType_BOOL:
        case DType_INT8:
        case DType_INT16:
        case DType_INT32:
            pad_value = (InEigenType)attribute->pad_const_int();
            break;
        case DType_FLOAT:
            pad_value = (InEigenType)attribute->pad_const_fp();
            break;
        default:
            // Unreachable for instantiated types; pad_value stays 0.
            printNodeValidationError("Unsupported data type");
            break;
    }

    // INT8 pad constant is expressed relative to the input zero point.
    if (this->qinfo && Dtype == DType_INT8)
    {
        pad_value += (InEigenType)this->qinfo->input_zp();
    }

    this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);

    return GraphNode::eval();
}
218
219template <int InRank, int OutRank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700220OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
221 TosaAttributeBase* attribute_,
222 TosaQuantInfoBase* qinfo_,
223 uint64_t id_)
224 : GraphNode(sgt_, Op_RESHAPE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700225{
226 setRequiredOperands(1, 1);
227 setRequiredRank(0, 6);
228
229 INIT_ATTRIBUTE(Reshape);
230}
231
232template <int InRank, int OutRank, DType Dtype>
233OpReshape<InRank, OutRank, Dtype>::~OpReshape()
234{
235 if (attribute)
236 delete attribute;
237}
238
// Validates RESHAPE operands and attributes.
// Checks rank limits, matching input/output element type, equal total element
// counts, and that the new_shape attribute matches the output tensor's shape
// dimension-by-dimension. Caches typed tensor pointers for eval().
// Returns 0 on success, 1 on validation failure.
template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpReshape: Input and output types must match");
        return 1;
    }

    // Reshape preserves the total number of elements.
    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
             "Input tensor size does not match output tensor size");

    for (uint32_t d = 0; d < OutRank; d++)
    {
        ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
                 "OpReshape: new_shape doesn't match output shape");
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}
271
// Reshapes the input into the attribute's new_shape.
// Eigen tensors are column-major but the reference model semantics are
// row-major, so the input dimensions are reversed (shuffle) before reshaping
// to a reversed target shape, and the result is reversed back. Rank-0/1
// tensors skip the shuffles, which Eigen cannot apply to them.
template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
    // Build the reversed target shape and the output dimension-reversal map.
    for (int32_t d = 0; d < OutRank; d++)
    {
        array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
        out_reverser[d] = OutRank - 1 - d;
    }

    // Input dimension-reversal map.
    for (int32_t d = 0; d < InRank; d++)
    {
        in_reverser[d] = InRank - 1 - d;
    }

    // Eigen Tensor is col-major, and we're referencing row-major result
    // need to reverse it to row-major before reshape, and perform another reverse afterward

    // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
    TIn in_reversed;
    if (InRank > 1)
    {
        in_reversed = in->getTensor().shuffle(in_reverser);
    }
    else
    {
        in_reversed = in->getTensor();
    }

    TOut in_reshaped = in_reversed.reshape(array_shape);

    // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
    if (OutRank > 1)
    {
        out->getTensor() = in_reshaped.shuffle(out_reverser);
    }
    else
    {
        out->getTensor() = in_reshaped;
    }

    return GraphNode::eval();
}
314
315template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700316OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
317 TosaAttributeBase* attribute_,
318 TosaQuantInfoBase* qinfo_,
319 uint64_t id_)
320 : GraphNode(sgt_, Op_REVERSE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700321{
322 setRequiredOperands(1, 1);
323 setRequiredRank(1, 6);
324
325 INIT_ATTRIBUTE(Axis);
326}
327
328template <int Rank, DType Dtype>
329OpReverse<Rank, Dtype>::~OpReverse()
330{
331 if (attribute)
332 delete attribute;
333}
334
// Validates REVERSE operands and attributes.
// Checks rank limits, that input and output agree in rank/type/shape, and
// that the axis attribute indexes a valid input dimension. Builds the
// per-dimension boolean `reverse_array` consumed by Eigen's reverse() in
// eval(). Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank/type/shape");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
    {
        printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
        return 1;
    }

    // transform list of axis into true or false list
    // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
    for (int i = 0; i < Rank; i++)
    {
        reverse_array[i] = false;
    }
    reverse_array[attribute->axis()] = true;

    return 0;
}
374
375template <int Rank, DType Dtype>
376int OpReverse<Rank, Dtype>::eval()
377{
378 out->getTensor() = in->getTensor().reverse(reverse_array);
379
380 return GraphNode::eval();
381}
382
383template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700384OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
385 TosaAttributeBase* attribute_,
386 TosaQuantInfoBase* qinfo_,
387 uint64_t id_)
388 : GraphNode(sgt_, Op_SLICE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700389{
390 setRequiredOperands(1, 1);
Kevin Chengcc61be32021-10-14 17:09:57 -0700391 setRequiredRank(1, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -0700392
393 INIT_ATTRIBUTE(Slice);
394}
395
396template <int Rank, DType Dtype>
397OpSlice<Rank, Dtype>::~OpSlice()
398{
399 if (attribute)
400 delete attribute;
401}
402
// Validates SLICE operands and attributes.
// Checks rank limits and matching input/output type; then verifies the
// start/size attribute vectors have one entry per input dimension and that
// each window [start, start+size) lies inside the input with a positive size
// that matches the output shape. Fills begin_array/size_array for eval().
// Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output type");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
             "OpSlice: begin array length needs to be rank(input)");
    ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");

    for (int32_t i = 0; i < in->getRank(); i++)
    {
        int32_t b = attribute->start()[i];
        int32_t s = attribute->size()[i];
        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
        ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
        ERROR_IF(s <= 0, "OpSlice: output must be positive");
        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
        begin_array[i] = b;
        size_array[i] = s;
    }

    return 0;
}
442
443template <int Rank, DType Dtype>
444int OpSlice<Rank, Dtype>::eval()
445{
446 out->getTensor() = in->getTensor().slice(begin_array, size_array);
447
448 return GraphNode::eval();
449}
450
451template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700452OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
453 TosaAttributeBase* attribute_,
454 TosaQuantInfoBase* qinfo_,
455 uint64_t id_)
456 : GraphNode(sgt_, Op_TILE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700457{
458 setRequiredOperands(1, 1);
459 setRequiredRank(0, 6);
460
461 INIT_ATTRIBUTE(Tile);
462}
463
464template <int Rank, DType Dtype>
465OpTileBase<Rank, Dtype>::~OpTileBase()
466{
467 if (attribute)
468 delete attribute;
469}
470
// Validates TILE operands and attributes.
// Checks rank limits, matching input/output rank+type, that the multiples
// attribute has one entry per dimension, and that each output dimension
// equals input_dim * multiple. Caches typed tensor pointers for the
// rank-specialized eval() implementations.
// Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same ranks and types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank or type");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->multiples().size() != Rank)
    {
        printNodeValidationError("1D list 'multiples' must have size equal to input rank");
        return 1;
    }

    // Every output dimension must be the input dimension scaled by its multiple.
    for (int32_t d = 0; d < Rank; d++)
    {
        if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
        {
            printNodeValidationError("unexpected output shape");
            return 1;
        }
    }

    return 0;
}
509
// Primary template for OpTile::eval(); only the rank 1-4 specializations
// below are implemented, so reaching this is a fatal configuration error.
template <int Rank, DType Dtype>
int OpTile<Rank, Dtype>::eval()
{
    // primary template shouldn't be called
    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
}
516
517template <DType Dtype>
518int OpTile<1, Dtype>::eval()
519{
520 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
521 {
522 int32_t id0 = od0 % this->in->getShape()[0];
523 this->out->getTensor()(od0) = this->in->getTensor()(id0);
524 }
525
526 return GraphNode::eval();
527}
528
529template <DType Dtype>
530int OpTile<2, Dtype>::eval()
531{
532 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
533 {
534 int32_t id0 = od0 % this->in->getShape()[0];
535 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
536 {
537 int32_t id1 = od1 % this->in->getShape()[1];
538 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
539 }
540 }
541
542 return GraphNode::eval();
543}
544
545template <DType Dtype>
546int OpTile<3, Dtype>::eval()
547{
548 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
549 {
550 int32_t id0 = od0 % this->in->getShape()[0];
551 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
552 {
553 int32_t id1 = od1 % this->in->getShape()[1];
554 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
555 {
556 int32_t id2 = od2 % this->in->getShape()[2];
557 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
558 }
559 }
560 }
561
562 return GraphNode::eval();
563}
564
565template <DType Dtype>
566int OpTile<4, Dtype>::eval()
567{
568 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
569 {
570 int32_t id0 = od0 % this->in->getShape()[0];
571 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
572 {
573 int32_t id1 = od1 % this->in->getShape()[1];
574 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
575 {
576 int32_t id2 = od2 % this->in->getShape()[2];
577 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
578 {
579 int32_t id3 = od3 % this->in->getShape()[3];
580 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
581 }
582 }
583 }
584 }
585
586 return GraphNode::eval();
587}
588
589template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700590OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
591 TosaAttributeBase* attribute_,
592 TosaQuantInfoBase* qinfo_,
593 uint64_t id_)
594 : GraphNode(sgt_, Op_TRANSPOSE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700595{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000596 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 setRequiredRank(0, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000598
599 INIT_ATTRIBUTE(Transpose);
Eric Kunzee5e26762020-10-13 16:11:07 -0700600}
601
602template <int Rank, DType Dtype>
603OpTranspose<Rank, Dtype>::~OpTranspose()
604{}
605
// Validates TRANSPOSE operands and attributes.
// Checks rank limits, matching input/output rank+type, equal element counts,
// and that perms is a valid permutation of [0, Rank): correct length, every
// index in range, no duplicates, and output dim d equal to the input dim it
// is drawn from. Fills perm_array for Eigen's shuffle() in eval().
// Returns 0 on success, 1 on validation failure.
template <int Rank, DType Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank and type");
        return 1;
    }

    if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
    {
        printNodeValidationError("Failure to match input and output total element count");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)");

    // Track which source dimensions have been used to reject duplicates.
    std::array<bool, Rank> index_used;
    index_used.fill(false);
    for (int32_t d = 0; d < Rank; d++)
    {
        int32_t index = attribute->perms()[d];
        ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
        ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
        index_used[index] = true;
        ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
        perm_array[d] = index;
    }

    return 0;
}
651
652template <int Rank, DType Dtype>
653int OpTranspose<Rank, Dtype>::eval()
654{
Eric Kunzee5e26762020-10-13 16:11:07 -0700655 out->getTensor() = in->getTensor().shuffle(perm_array);
656
657 return GraphNode::eval();
658}
659
// template explicit instantiation
// Each macro instantiates the op template for all supported ranks at the
// given element type (FLOAT, INT8, INT16, INT32, BOOL).
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);

// Reshape is instantiated over (input rank, output rank) pairs.
DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);