// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "data_layout.h"
#include "quant_util.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

template <int Rank, DType Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_CONCAT, id_)
{
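    // note: the -1 required-operand count appears to mark Concat as variadic;
    // the "at least one input" check is done in checkTensorAttributes() below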
    setRequiredOperands(-1, 1);
    setRequiredRank(1, 6);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
OpConcat<Rank, Dtype>::~OpConcat()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (inputs.empty())
    {
        printNodeValidationError("Concat operator must have at least one input tensor");
        return 1;
    }

    int32_t num_inputs = inputs.size();

    // output and input must be the same types and rank
    for (int32_t i = 0; i < num_inputs; i++)
    {
        if (inputs[i]->matchRankType(*outputs[0]))
        {
            printNodeValidationError("OpConcat: input ranks and types must match");
            return 1;
        }
        ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
    }

    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
    {
        printNodeValidationError("OpConcat: axis is beyond output tensor rank");
        return 1;
    }

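    // Shape rules for concat: every non-axis dimension of each input must match
    // the output, and the output's axis dimension must equal the sum of the
    // inputs' axis dimensions, e.g. (illustrative) concatenating [2,3] and [4,3]
    // along axis 0 requires an output shape of [6,3].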
    int32_t output_dim_on_axis = 0;
    for (int32_t j = 0; j < num_inputs; j++)
    {
        for (int32_t i = 0; i < Rank; i++)
        {
            int32_t input_dim = inputs[j]->getShape()[i];
            if (i == attribute->axis())
            {
                output_dim_on_axis += input_dim;
            }
            else if (input_dim != outputs[0]->getShape()[i])
            {
                printNodeValidationError("OpConcat: input dimension not matching output dimension");
                return 1;
            }
        }
    }

    ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
             "OpConcat: sum of input dimensions on axis not equal to output dimension on axis");

    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}

template <int Rank, DType Dtype>
int OpConcat<Rank, Dtype>::eval()
{

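    // Eigen tensors are col-major while the reference model data is row-major
    // (see the note in OpReshape::eval() below), so the dimension order and the
    // concat axis are both reversed before Eigen's concatenate(), and the result
    // is shuffled back at the end. e.g. (illustrative) for Rank=3 and axis=0,
    // inputs are shuffled to [d2, d1, d0] and concatenated along reversed_axis=2.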
    int32_t reversed_axis = Rank - 1 - attribute->axis();

    for (int32_t d = 0; d < Rank; d++)
    {
        reverser[d] = Rank - 1 - d;
    }

    TIn result = ins[0]->getTensor().shuffle(reverser);

    for (size_t i = 1; i < ins.size(); i++)
    {
        TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
        TIn temp        = result.concatenate(in_reversed, reversed_axis);
        result          = temp;
    }
    out->getTensor() = result.shuffle(reverser);

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
                          TosaAttributeBase* attribute_,
                          uint64_t id_)
    : GraphNode(sgt_, Op_PAD, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    INIT_ATTRIBUTE(Pad);
}

template <int Rank, DType Dtype>
OpPad<Rank, Dtype>::~OpPad()
{
}

template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same rank and type
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output type and rank");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    ASSERT_MEM(in && out);

    // padding in the spec is a 2D array of shape [Rank, 2];
    // the reference model implements this as a 1D array of length Rank * 2, with the ordering:
    // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
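    // e.g. (illustrative) for Rank=2, spec padding [[1, 2], [3, 4]] arrives here
    // as the flattened array {1, 2, 3, 4}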
    ERROR_IF(attribute->padding().size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");

    for (int i = 0; i < Rank; i++)
    {
        int32_t pad_front = attribute->padding()[2 * i];
        int32_t pad_back  = attribute->padding()[2 * i + 1];
        ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
        ERROR_IF(out->getShape()[i] != pad_front + in->getShape()[i] + pad_back,
                 "OpPad: output shape not equal to input plus padding");
        paddings_array[i] = std::make_pair(pad_front, pad_back);
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::eval()
{
    InEigenType pad_value = 0;

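    // select the pad constant from the attribute: integer/bool types carry it in
    // pad_const_int(), floating-point types in pad_const_fp()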
    switch (Dtype)
    {
        case DType_BOOL:
        case DType_INT8:
        case DType_INT16:
        case DType_INT32:
            pad_value = (InEigenType)attribute->pad_const_int();
            break;
        case DType_FP16:
        case DType_FLOAT:
            pad_value = (InEigenType)attribute->pad_const_fp();
            break;
        default:
            printNodeValidationError("Unsupported data type");
            break;
    }

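    // Eigen's pad() extends each dimension by the (front, back) amounts collected
    // in checkTensorAttributes(), filling the new elements with pad_value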
    this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);

    return GraphNode::eval();
}

template <int InRank, int OutRank, DType Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
                                             TosaAttributeBase* attribute_,
                                             uint64_t id_)
    : GraphNode(sgt_, Op_RESHAPE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    INIT_ATTRIBUTE(Reshape);
}

template <int InRank, int OutRank, DType Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
{
    if (attribute)
        delete attribute;
}

template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpReshape: Input and output types must match");
        return 1;
    }

    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
             "Input tensor size does not match output tensor size");

    for (uint32_t d = 0; d < OutRank; d++)
    {
        ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
                 "OpReshape: new_shape doesn't match output shape");
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}

template <int InRank, int OutRank, DType Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
    for (int32_t d = 0; d < OutRank; d++)
    {
        array_shape[d]  = attribute->new_shape()[OutRank - 1 - d];
        out_reverser[d] = OutRank - 1 - d;
    }

    for (int32_t d = 0; d < InRank; d++)
    {
        in_reverser[d] = InRank - 1 - d;
    }

    // Eigen Tensor is col-major, and we're referencing a row-major result;
    // it needs to be reversed to row-major before the reshape, with another reverse performed afterward
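    // e.g. (illustrative) reshaping row-major [2, 3] to [3, 2]: the input is
    // shuffled to dims [3, 2] (same memory layout as the row-major original),
    // reshaped to the reversed new_shape {2, 3}, then shuffled back to give [3, 2]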

    // a rank-0 input tensor can't do .shuffle(), so it needs to be handled separately
    TIn in_reversed;
    if (InRank > 1)
    {
        in_reversed = in->getTensor().shuffle(in_reverser);
    }
    else
    {
        in_reversed = in->getTensor();
    }

    TOut in_reshaped = in_reversed.reshape(array_shape);

    // the output tensor can be rank 0, where .reshape() and .shuffle() don't work, so it needs to be handled separately
    if (OutRank > 1)
    {
        out->getTensor() = in_reshaped.shuffle(out_reverser);
    }
    else
    {
        out->getTensor() = in_reshaped;
    }

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
                                  TosaAttributeBase* attribute_,
                                  uint64_t id_)
    : GraphNode(sgt_, Op_REVERSE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 6);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
OpReverse<Rank, Dtype>::~OpReverse()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same rank, type, and shape
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank/type/shape");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
    {
        printNodeValidationError("Reverse axis must be between [0, input_rank - 1]");
        return 1;
    }

    // transform the single reverse axis into a per-dimension true/false list,
    // e.g. for rank=4 and axis=1, the reverse array would be [false, true, false, false]
    for (int i = 0; i < Rank; i++)
    {
        reverse_array[i] = false;
    }
    reverse_array[attribute->axis()] = true;

    return 0;
}

template <int Rank, DType Dtype>
int OpReverse<Rank, Dtype>::eval()
{
    out->getTensor() = in->getTensor().reverse(reverse_array);

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
                              TosaAttributeBase* attribute_,
                              uint64_t id_)
    : GraphNode(sgt_, Op_SLICE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 4);

    INIT_ATTRIBUTE(Slice);
}

template <int Rank, DType Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output type");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
             "OpSlice: begin array length needs to be rank(input)");
    ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");

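    // each output dimension i takes input elements [start[i], start[i] + size[i]),
    // e.g. (illustrative) a shape-[10] input with start=2, size=3 selects
    // elements 2, 3 and 4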
    for (int32_t i = 0; i < in->getRank(); i++)
    {
        int32_t b = attribute->start()[i];
        int32_t s = attribute->size()[i];
        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
        ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
        ERROR_IF(s <= 0, "OpSlice: size must be positive");
        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
        begin_array[i] = b;
        size_array[i]  = s;
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpSlice<Rank, Dtype>::eval()
{
    out->getTensor() = in->getTensor().slice(begin_array, size_array);

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
                                    TosaAttributeBase* attribute_,
                                    uint64_t id_)
    : GraphNode(sgt_, Op_TILE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    INIT_ATTRIBUTE(Tile);
}

template <int Rank, DType Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same ranks and types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank or type");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->multiples().size() != Rank)
    {
        printNodeValidationError("1D list 'multiples' must have size equal to input rank");
        return 1;
    }

    for (int32_t d = 0; d < Rank; d++)
    {
        ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d],
                 "Output shape not equal to input * multiples");
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpTile<Rank, Dtype>::eval()
{
    // primary template shouldn't be called
    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
}

template <DType Dtype>
int OpTile<1, Dtype>::eval()
{
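    // tiling wraps each output index back into the input with a modulo:
    // id = od % in_shape, e.g. (illustrative) input {a, b} with multiples = [3]
    // produces {a, b, a, b, a, b}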
    for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
    {
        int32_t id0 = od0 % this->in->getShape()[0];
        this->out->getTensor()(od0) = this->in->getTensor()(id0);
    }

    return GraphNode::eval();
}

template <DType Dtype>
int OpTile<2, Dtype>::eval()
{
    for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
    {
        int32_t id0 = od0 % this->in->getShape()[0];
        for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
        {
            int32_t id1 = od1 % this->in->getShape()[1];
            this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
        }
    }

    return GraphNode::eval();
}

template <DType Dtype>
int OpTile<3, Dtype>::eval()
{
    for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
    {
        int32_t id0 = od0 % this->in->getShape()[0];
        for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
        {
            int32_t id1 = od1 % this->in->getShape()[1];
            for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
            {
                int32_t id2 = od2 % this->in->getShape()[2];
                this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
            }
        }
    }

    return GraphNode::eval();
}

template <DType Dtype>
int OpTile<4, Dtype>::eval()
{
    for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
    {
        int32_t id0 = od0 % this->in->getShape()[0];
        for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
        {
            int32_t id1 = od1 % this->in->getShape()[1];
            for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
            {
                int32_t id2 = od2 % this->in->getShape()[2];
                for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
                {
                    int32_t id3 = od3 % this->in->getShape()[3];
                    this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
                }
            }
        }
    }

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
                                      TosaAttributeBase* attribute_,
                                      uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    INIT_ATTRIBUTE(Transpose);
}

template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::~OpTranspose()
{}

template <int Rank, DType Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same rank and type
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank and type");
        return 1;
    }

    if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
    {
        printNodeValidationError("Failure to match input and output total element count");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)");

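    // perms must be a permutation of [0, Rank): output dimension d takes input
    // dimension perms()[d], e.g. (illustrative) an input of shape [4, 5, 6] with
    // perms = {2, 0, 1} requires an output of shape [6, 4, 5]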
    std::array<bool, Rank> index_used;
    index_used.fill(false);
    for (int32_t d = 0; d < Rank; d++)
    {
        int32_t index = attribute->perms()[d];
        ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
        ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
        index_used[index] = true;
        ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
        perm_array[d] = index;
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpTranspose<Rank, Dtype>::eval()
{
    out->getTensor() = in->getTensor().shuffle(perm_array);

    return GraphNode::eval();
}

// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);

DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);