blob: ce5b5afd6d0b903a4fcb9bf27e3ebc2ba40eb4e4 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Luke Huttona4e48ca2023-02-22 11:53:48 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "data_layout.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
23template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070024OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
25 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070026 uint64_t id_)
27 : GraphNode(sgt_, Op_CONCAT, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070028{
Kevin Chengad15dfa2021-03-04 15:15:03 -080029 setRequiredOperands(-1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070030 setRequiredRank(1, 6);
31
32 INIT_ATTRIBUTE(Axis);
33}
34
35template <int Rank, DType Dtype>
36OpConcat<Rank, Dtype>::~OpConcat()
37{
38 if (attribute)
39 delete attribute;
40}
41
42template <int Rank, DType Dtype>
43int OpConcat<Rank, Dtype>::checkTensorAttributes()
44{
45 if (validateRequiredOperands())
46 return 1;
47
Kevin Chengad15dfa2021-03-04 15:15:03 -080048 if (inputs.empty())
Eric Kunzee5e26762020-10-13 16:11:07 -070049 {
Kevin Chengad15dfa2021-03-04 15:15:03 -080050 printNodeValidationError("Concat operator must have at least one input tensor");
Eric Kunzee5e26762020-10-13 16:11:07 -070051 return 1;
52 }
Kevin Chengcc61be32021-10-14 17:09:57 -070053
54 int32_t num_inputs = inputs.size();
55
Eric Kunzee5e26762020-10-13 16:11:07 -070056 // output and input must be the same types and rank
Kevin Chengcc61be32021-10-14 17:09:57 -070057 for (int32_t i = 0; i < num_inputs; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
Kevin Chengad15dfa2021-03-04 15:15:03 -080059 if (inputs[i]->matchRankType(*outputs[0]))
60 {
Kevin Chengcc61be32021-10-14 17:09:57 -070061 printNodeValidationError("OpConcat: input ranks and types must match");
Kevin Chengad15dfa2021-03-04 15:15:03 -080062 return 1;
63 }
64 ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
Eric Kunzee5e26762020-10-13 16:11:07 -070065 }
66
Kevin Chengcc61be32021-10-14 17:09:57 -070067 if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
Eric Kunzee5e26762020-10-13 16:11:07 -070068 {
Kevin Chengcc61be32021-10-14 17:09:57 -070069 printNodeValidationError("OpConcat: axis is beyond output tensor rank");
Eric Kunzee5e26762020-10-13 16:11:07 -070070 return 1;
71 }
72
Kevin Chengcc61be32021-10-14 17:09:57 -070073 int32_t output_dim_on_axis = 0;
74 for (int32_t j = 0; j < num_inputs; j++)
75 {
76 for (int32_t i = 0; i < Rank; i++)
77 {
78 int32_t input_dim = inputs[j]->getShape()[i];
79 if (i == attribute->axis())
80 {
81 output_dim_on_axis += input_dim;
82 }
83 else if (input_dim != outputs[0]->getShape()[i])
84 {
85 printNodeValidationError("OpConcat: input dimension not matching output dimension");
86 return 1;
87 }
88 }
89 }
90
Kevin Cheng6e528662021-10-20 17:35:33 +000091 ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
Kevin Chengcc61be32021-10-14 17:09:57 -070092 "OpConcat: sum of input dimension on axis not equal to output dimension on axis");
93
94 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
95
Eric Kunzee5e26762020-10-13 16:11:07 -070096 return 0;
97}
98
99template <int Rank, DType Dtype>
100int OpConcat<Rank, Dtype>::eval()
101{
102
103 int32_t reversed_axis = Rank - 1 - attribute->axis();
104
105 for (int32_t d = 0; d < Rank; d++)
106 {
107 reverser[d] = Rank - 1 - d;
108 }
109
Kevin Chengad15dfa2021-03-04 15:15:03 -0800110 TIn result = ins[0]->getTensor().shuffle(reverser);
Eric Kunzee5e26762020-10-13 16:11:07 -0700111
Kevin Chengad15dfa2021-03-04 15:15:03 -0800112 for (size_t i = 1; i < ins.size(); i++)
113 {
114 TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
115 TIn temp = result.concatenate(in_reversed, reversed_axis);
116 result = temp;
117 }
118 out->getTensor() = result.shuffle(reverser);
Eric Kunzee5e26762020-10-13 16:11:07 -0700119
120 return GraphNode::eval();
121}
122
123template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700124OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
125 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700126 uint64_t id_)
127 : GraphNode(sgt_, Op_PAD, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700128{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000129 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000130 setRequiredRank(1, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -0700131
Kevin Chengfe392ce2021-10-18 21:51:55 +0000132 INIT_ATTRIBUTE(Pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700133}
134
135template <int Rank, DType Dtype>
136OpPad<Rank, Dtype>::~OpPad()
137{
Eric Kunzee5e26762020-10-13 16:11:07 -0700138}
139
140template <int Rank, DType Dtype>
141int OpPad<Rank, Dtype>::checkTensorAttributes()
142{
143 if (validateRequiredOperands())
144 return 1;
145
146 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
147 {
148 return 1;
149 }
150
151 // output and input must be the same types
152 if (inputs[0]->matchRankType(*outputs[0]))
153 {
154 printNodeValidationError("Failure to match input and output type and rank");
155 return 1;
156 }
157
Kevin Chengfe392ce2021-10-18 21:51:55 +0000158 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
159 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
160 ASSERT_MEM(in && out);
161
162 // padding in spec is 2D array in shape of [Rank, 2]
163 // Reference model implement this as 1D array of [Rank * 2], with ordering:
164 // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
165 ERROR_IF(attribute->padding().size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");
166
167 for (int i = 0; i < Rank; i++)
168 {
169 int32_t pad_front = attribute->padding()[2 * i];
170 int32_t pad_back = attribute->padding()[2 * i + 1];
171 ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
Eric Kunze3c59d5d2022-08-15 11:30:33 -0700172 ERROR_IF(out->getShape()[i] != pad_front + in->getShape()[i] + pad_back,
173 "OpPad: output shape not equal to input plus padding");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000174 paddings_array[i] = std::make_pair(pad_front, pad_back);
175 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700176
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 return 0;
178}
179
180template <int Rank, DType Dtype>
181int OpPad<Rank, Dtype>::eval()
182{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000183 InEigenType pad_value = 0;
184
185 switch (Dtype)
Kevin Chengcc61be32021-10-14 17:09:57 -0700186 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000187 case DType_BOOL:
188 case DType_INT8:
189 case DType_INT16:
190 case DType_INT32:
191 pad_value = (InEigenType)attribute->pad_const_int();
192 break;
James Ward8b390432022-08-12 20:48:56 +0100193 case DType_FP16:
James Ward34071252022-12-07 15:48:47 +0000194 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100195 case DType_FP32:
Kevin Chengfe392ce2021-10-18 21:51:55 +0000196 pad_value = (InEigenType)attribute->pad_const_fp();
197 break;
TatWai Chong86c403b2022-06-06 20:46:01 -0700198 default:
199 printNodeValidationError("Unsupported data type");
200 break;
Kevin Chengcc61be32021-10-14 17:09:57 -0700201 }
202
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
204
205 return GraphNode::eval();
206}
207
208template <int InRank, int OutRank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700209OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
210 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700211 uint64_t id_)
212 : GraphNode(sgt_, Op_RESHAPE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700213{
214 setRequiredOperands(1, 1);
215 setRequiredRank(0, 6);
216
217 INIT_ATTRIBUTE(Reshape);
218}
219
220template <int InRank, int OutRank, DType Dtype>
221OpReshape<InRank, OutRank, Dtype>::~OpReshape()
222{
223 if (attribute)
224 delete attribute;
225}
226
227template <int InRank, int OutRank, DType Dtype>
228int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
229{
Eric Kunzee5e26762020-10-13 16:11:07 -0700230 if (validateRequiredOperands())
231 return 1;
232
233 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
234 {
235 return 1;
236 }
237
238 // output and input must be the same types
239 if (inputs[0]->matchType(*outputs[0]))
240 {
241 printNodeValidationError("OpReshape: Input and output types must match");
242 return 1;
243 }
244
Kevin Chengcc61be32021-10-14 17:09:57 -0700245 ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
246 "Input tensor size does not match output tensor size");
247
Eric Kunzee5e26762020-10-13 16:11:07 -0700248 for (uint32_t d = 0; d < OutRank; d++)
249 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700250 ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
Jeremy Johnsonc23fc3b2022-05-30 16:51:21 +0100251 "OpReshape: new_shape doesn't match output shape");
Eric Kunzee5e26762020-10-13 16:11:07 -0700252 }
253
254 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
255 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
256
257 return 0;
258}
259
260template <int InRank, int OutRank, DType Dtype>
261int OpReshape<InRank, OutRank, Dtype>::eval()
262{
Eric Kunzee5e26762020-10-13 16:11:07 -0700263 for (int32_t d = 0; d < OutRank; d++)
264 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700265 array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
Eric Kunzee5e26762020-10-13 16:11:07 -0700266 out_reverser[d] = OutRank - 1 - d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700267 }
268
269 for (int32_t d = 0; d < InRank; d++)
270 {
271 in_reverser[d] = InRank - 1 - d;
272 }
273
274 // Eigen Tensor is col-major, and we're referencing row-major result
275 // need to reverse it to row-major before reshape, and perform another reverse afterward
276
277 // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
278 TIn in_reversed;
279 if (InRank > 1)
280 {
281 in_reversed = in->getTensor().shuffle(in_reverser);
282 }
283 else
284 {
285 in_reversed = in->getTensor();
286 }
287
288 TOut in_reshaped = in_reversed.reshape(array_shape);
289
290 // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
291 if (OutRank > 1)
292 {
293 out->getTensor() = in_reshaped.shuffle(out_reverser);
294 }
295 else
296 {
297 out->getTensor() = in_reshaped;
298 }
299
300 return GraphNode::eval();
301}
302
303template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700304OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
305 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700306 uint64_t id_)
307 : GraphNode(sgt_, Op_REVERSE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700308{
309 setRequiredOperands(1, 1);
310 setRequiredRank(1, 6);
311
312 INIT_ATTRIBUTE(Axis);
313}
314
315template <int Rank, DType Dtype>
316OpReverse<Rank, Dtype>::~OpReverse()
317{
318 if (attribute)
319 delete attribute;
320}
321
322template <int Rank, DType Dtype>
323int OpReverse<Rank, Dtype>::checkTensorAttributes()
324{
325 if (validateRequiredOperands())
326 return 1;
327
328 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
329 {
330 return 1;
331 }
332
333 // output and input must be the same types
334 if (inputs[0]->matchRankTypeShape(*outputs[0]))
335 {
336 printNodeValidationError("Failure to match input and output rank/type/shape");
337 return 1;
338 }
339
340 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
341 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
342
343 ASSERT_MEM(in && out);
344
345 if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
346 {
347 printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
348 return 1;
349 }
350
351 // transform list of axis into true or false list
352 // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
353 for (int i = 0; i < Rank; i++)
354 {
355 reverse_array[i] = false;
356 }
357 reverse_array[attribute->axis()] = true;
358
359 return 0;
360}
361
362template <int Rank, DType Dtype>
363int OpReverse<Rank, Dtype>::eval()
364{
365 out->getTensor() = in->getTensor().reverse(reverse_array);
366
367 return GraphNode::eval();
368}
369
370template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700371OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
372 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700373 uint64_t id_)
374 : GraphNode(sgt_, Op_SLICE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700375{
376 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000377 setRequiredRank(1, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -0700378
379 INIT_ATTRIBUTE(Slice);
380}
381
382template <int Rank, DType Dtype>
383OpSlice<Rank, Dtype>::~OpSlice()
384{
385 if (attribute)
386 delete attribute;
387}
388
389template <int Rank, DType Dtype>
390int OpSlice<Rank, Dtype>::checkTensorAttributes()
391{
392 if (validateRequiredOperands())
393 return 1;
394
395 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
396 {
397 return 1;
398 }
399
400 // output and input must be the same types
Luke Huttona4e48ca2023-02-22 11:53:48 +0000401 if (inputs[0]->matchRankType(*outputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700402 {
Luke Huttona4e48ca2023-02-22 11:53:48 +0000403 printNodeValidationError("Failure to match input and output rank or type");
Eric Kunzee5e26762020-10-13 16:11:07 -0700404 return 1;
405 }
406
407 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
408 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
409
Luke Huttona4e48ca2023-02-22 11:53:48 +0000410 ASSERT_MEM(in && out);
411
TatWai Chong86c403b2022-06-06 20:46:01 -0700412 ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
Kevin Chengcc61be32021-10-14 17:09:57 -0700413 "OpSlice: begin array length needs to be rank(input)");
414 ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
Eric Kunzee5e26762020-10-13 16:11:07 -0700415
Kevin Chengcc61be32021-10-14 17:09:57 -0700416 for (int32_t i = 0; i < in->getRank(); i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700417 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700418 int32_t b = attribute->start()[i];
Kevin Chengcc61be32021-10-14 17:09:57 -0700419 int32_t s = attribute->size()[i];
420 ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
421 ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
422 ERROR_IF(s <= 0, "OpSlice: output must be positive");
423 ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
424 begin_array[i] = b;
425 size_array[i] = s;
Eric Kunzee5e26762020-10-13 16:11:07 -0700426 }
427
428 return 0;
429}
430
431template <int Rank, DType Dtype>
432int OpSlice<Rank, Dtype>::eval()
433{
434 out->getTensor() = in->getTensor().slice(begin_array, size_array);
435
436 return GraphNode::eval();
437}
438
439template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700440OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
441 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700442 uint64_t id_)
443 : GraphNode(sgt_, Op_TILE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700444{
445 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000446 setRequiredRank(1, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -0700447
448 INIT_ATTRIBUTE(Tile);
449}
450
451template <int Rank, DType Dtype>
452OpTileBase<Rank, Dtype>::~OpTileBase()
453{
454 if (attribute)
455 delete attribute;
456}
457
458template <int Rank, DType Dtype>
459int OpTileBase<Rank, Dtype>::checkTensorAttributes()
460{
461 if (validateRequiredOperands())
462 return 1;
463
464 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
465 {
466 return 1;
467 }
468
469 // output and input must be the same ranks and types
470 if (inputs[0]->matchRankType(*outputs[0]))
471 {
472 printNodeValidationError("Failure to match input and output rank or type");
473 return 1;
474 }
475
476 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
477 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
478
Luke Huttona4e48ca2023-02-22 11:53:48 +0000479 ASSERT_MEM(in && out);
480
Eric Kunzee5e26762020-10-13 16:11:07 -0700481 if (attribute->multiples().size() != Rank)
482 {
483 printNodeValidationError("1D list 'multiples' must have size equal to input rank");
484 return 1;
485 }
486
487 for (int32_t d = 0; d < Rank; d++)
488 {
Eric Kunzeec0327d2022-05-24 15:22:06 -0700489 ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d],
490 "Output shape not equal to input * multiples;")
Eric Kunzee5e26762020-10-13 16:11:07 -0700491 }
492
493 return 0;
494}
495
496template <int Rank, DType Dtype>
497int OpTile<Rank, Dtype>::eval()
498{
499 // primary template shouldn't be called
Kevin Chengacb550f2021-06-29 15:32:19 -0700500 FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700501}
502
503template <DType Dtype>
504int OpTile<1, Dtype>::eval()
505{
506 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
507 {
508 int32_t id0 = od0 % this->in->getShape()[0];
509 this->out->getTensor()(od0) = this->in->getTensor()(id0);
510 }
511
512 return GraphNode::eval();
513}
514
515template <DType Dtype>
516int OpTile<2, Dtype>::eval()
517{
518 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
519 {
520 int32_t id0 = od0 % this->in->getShape()[0];
521 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
522 {
523 int32_t id1 = od1 % this->in->getShape()[1];
524 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
525 }
526 }
527
528 return GraphNode::eval();
529}
530
531template <DType Dtype>
532int OpTile<3, Dtype>::eval()
533{
534 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
535 {
536 int32_t id0 = od0 % this->in->getShape()[0];
537 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
538 {
539 int32_t id1 = od1 % this->in->getShape()[1];
540 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
541 {
542 int32_t id2 = od2 % this->in->getShape()[2];
543 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
544 }
545 }
546 }
547
548 return GraphNode::eval();
549}
550
551template <DType Dtype>
552int OpTile<4, Dtype>::eval()
553{
554 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
555 {
556 int32_t id0 = od0 % this->in->getShape()[0];
557 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
558 {
559 int32_t id1 = od1 % this->in->getShape()[1];
560 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
561 {
562 int32_t id2 = od2 % this->in->getShape()[2];
563 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
564 {
565 int32_t id3 = od3 % this->in->getShape()[3];
566 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
567 }
568 }
569 }
570 }
571
572 return GraphNode::eval();
573}
574
Luke Huttona4e48ca2023-02-22 11:53:48 +0000575template <DType Dtype>
576int OpTile<5, Dtype>::eval()
577{
578 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
579 {
580 int32_t id0 = od0 % this->in->getShape()[0];
581 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
582 {
583 int32_t id1 = od1 % this->in->getShape()[1];
584 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
585 {
586 int32_t id2 = od2 % this->in->getShape()[2];
587 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
588 {
589 int32_t id3 = od3 % this->in->getShape()[3];
590 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
591 {
592 int32_t id4 = od4 % this->in->getShape()[4];
593 this->out->getTensor()(od0, od1, od2, od3, od4) =
594 this->in->getTensor()(id0, id1, id2, id3, id4);
595 }
596 }
597 }
598 }
599 }
600
601 return GraphNode::eval();
602}
603
604template <DType Dtype>
605int OpTile<6, Dtype>::eval()
606{
607 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
608 {
609 int32_t id0 = od0 % this->in->getShape()[0];
610 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
611 {
612 int32_t id1 = od1 % this->in->getShape()[1];
613 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
614 {
615 int32_t id2 = od2 % this->in->getShape()[2];
616 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
617 {
618 int32_t id3 = od3 % this->in->getShape()[3];
619 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
620 {
621 int32_t id4 = od4 % this->in->getShape()[4];
622 for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
623 {
624 int32_t id5 = od5 % this->in->getShape()[5];
625 this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
626 this->in->getTensor()(id0, id1, id2, id3, id4, id5);
627 }
628 }
629 }
630 }
631 }
632 }
633
634 return GraphNode::eval();
635}
636
Eric Kunzee5e26762020-10-13 16:11:07 -0700637template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700638OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
639 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700640 uint64_t id_)
641 : GraphNode(sgt_, Op_TRANSPOSE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700642{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000643 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000644 setRequiredRank(1, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000645
646 INIT_ATTRIBUTE(Transpose);
Eric Kunzee5e26762020-10-13 16:11:07 -0700647}
648
649template <int Rank, DType Dtype>
650OpTranspose<Rank, Dtype>::~OpTranspose()
Jerry Gea6827492022-11-16 10:41:55 -0800651{
652 if (attribute) delete attribute;
653}
Eric Kunzee5e26762020-10-13 16:11:07 -0700654
655template <int Rank, DType Dtype>
656int OpTranspose<Rank, Dtype>::checkTensorAttributes()
657{
658 if (validateRequiredOperands())
659 return 1;
660
661 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
662 {
663 return 1;
664 }
665
666 // output and input must be the same types
667 if (inputs[0]->matchRankType(*outputs[0]))
668 {
669 printNodeValidationError("Failure to match input and output rank and type");
670 return 1;
671 }
672
673 if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
674 {
675 printNodeValidationError("Failure to match input and output total element count");
676 return 1;
677 }
678
Kevin Chengfe392ce2021-10-18 21:51:55 +0000679 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
680 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
681
682 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700683
TatWai Chong86c403b2022-06-06 20:46:01 -0700684 ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000685
686 std::array<bool, Rank> index_used;
687 index_used.fill(false);
688 for (int32_t d = 0; d < Rank; d++)
689 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700690 int32_t index = attribute->perms()[d];
Kevin Chengf3e016f2021-11-02 01:15:50 +0000691 ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
692 ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
693 index_used[index] = true;
694 ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
695 perm_array[d] = index;
696 }
697
Eric Kunzee5e26762020-10-13 16:11:07 -0700698 return 0;
699}
700
701template <int Rank, DType Dtype>
702int OpTranspose<Rank, Dtype>::eval()
703{
Eric Kunzee5e26762020-10-13 16:11:07 -0700704 out->getTensor() = in->getTensor().shuffle(perm_array);
705
706 return GraphNode::eval();
707}
708
// template explicit instantiation
// Instantiate each data-layout operator for every supported element type.
// The RANK1_6 macros cover ranks 1-6; DEF_INSTANTIATE_RESHAPE covers every
// (input rank, output rank) pair for OpReshape's two rank parameters.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);

DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);