blob: a18946610a6b3b66abab7607c50a398f98a5b704 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "data_layout.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
23template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070024OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
25 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070026 uint64_t id_)
27 : GraphNode(sgt_, Op_CONCAT, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070028{
Kevin Chengad15dfa2021-03-04 15:15:03 -080029 setRequiredOperands(-1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070030 setRequiredRank(1, 6);
31
32 INIT_ATTRIBUTE(Axis);
33}
34
35template <int Rank, DType Dtype>
36OpConcat<Rank, Dtype>::~OpConcat()
37{
38 if (attribute)
39 delete attribute;
40}
41
42template <int Rank, DType Dtype>
43int OpConcat<Rank, Dtype>::checkTensorAttributes()
44{
Jerry Gea793f462023-04-11 00:05:02 +000045 // Check Tosa Level
46 auto tosa_level = g_func_config.tosa_level;
47 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
48
Eric Kunzee5e26762020-10-13 16:11:07 -070049 if (validateRequiredOperands())
50 return 1;
51
Kevin Chengad15dfa2021-03-04 15:15:03 -080052 if (inputs.empty())
Eric Kunzee5e26762020-10-13 16:11:07 -070053 {
Kevin Chengad15dfa2021-03-04 15:15:03 -080054 printNodeValidationError("Concat operator must have at least one input tensor");
Eric Kunzee5e26762020-10-13 16:11:07 -070055 return 1;
56 }
Kevin Chengcc61be32021-10-14 17:09:57 -070057
58 int32_t num_inputs = inputs.size();
59
Eric Kunzee5e26762020-10-13 16:11:07 -070060 // output and input must be the same types and rank
Kevin Chengcc61be32021-10-14 17:09:57 -070061 for (int32_t i = 0; i < num_inputs; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070062 {
Kevin Chengad15dfa2021-03-04 15:15:03 -080063 if (inputs[i]->matchRankType(*outputs[0]))
64 {
Kevin Chengcc61be32021-10-14 17:09:57 -070065 printNodeValidationError("OpConcat: input ranks and types must match");
Kevin Chengad15dfa2021-03-04 15:15:03 -080066 return 1;
67 }
68 ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
Eric Kunzee5e26762020-10-13 16:11:07 -070069 }
70
Kevin Chengcc61be32021-10-14 17:09:57 -070071 if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
Eric Kunzee5e26762020-10-13 16:11:07 -070072 {
Kevin Chengcc61be32021-10-14 17:09:57 -070073 printNodeValidationError("OpConcat: axis is beyond output tensor rank");
Eric Kunzee5e26762020-10-13 16:11:07 -070074 return 1;
75 }
76
Kevin Chengcc61be32021-10-14 17:09:57 -070077 int32_t output_dim_on_axis = 0;
78 for (int32_t j = 0; j < num_inputs; j++)
79 {
80 for (int32_t i = 0; i < Rank; i++)
81 {
82 int32_t input_dim = inputs[j]->getShape()[i];
83 if (i == attribute->axis())
84 {
85 output_dim_on_axis += input_dim;
86 }
87 else if (input_dim != outputs[0]->getShape()[i])
88 {
89 printNodeValidationError("OpConcat: input dimension not matching output dimension");
90 return 1;
91 }
92 }
93 }
94
Kevin Cheng6e528662021-10-20 17:35:33 +000095 ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
Kevin Chengcc61be32021-10-14 17:09:57 -070096 "OpConcat: sum of input dimension on axis not equal to output dimension on axis");
97
98 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
99
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 return 0;
101}
102
103template <int Rank, DType Dtype>
104int OpConcat<Rank, Dtype>::eval()
105{
106
107 int32_t reversed_axis = Rank - 1 - attribute->axis();
108
109 for (int32_t d = 0; d < Rank; d++)
110 {
111 reverser[d] = Rank - 1 - d;
112 }
113
Kevin Chengad15dfa2021-03-04 15:15:03 -0800114 TIn result = ins[0]->getTensor().shuffle(reverser);
Eric Kunzee5e26762020-10-13 16:11:07 -0700115
Kevin Chengad15dfa2021-03-04 15:15:03 -0800116 for (size_t i = 1; i < ins.size(); i++)
117 {
118 TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
119 TIn temp = result.concatenate(in_reversed, reversed_axis);
120 result = temp;
121 }
122 out->getTensor() = result.shuffle(reverser);
Eric Kunzee5e26762020-10-13 16:11:07 -0700123
124 return GraphNode::eval();
125}
126
127template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700128OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
129 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700130 uint64_t id_)
131 : GraphNode(sgt_, Op_PAD, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700132{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000133 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000134 setRequiredRank(1, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -0700135
Kevin Chengfe392ce2021-10-18 21:51:55 +0000136 INIT_ATTRIBUTE(Pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700137}
138
139template <int Rank, DType Dtype>
140OpPad<Rank, Dtype>::~OpPad()
141{
Eric Kunzee5e26762020-10-13 16:11:07 -0700142}
143
144template <int Rank, DType Dtype>
145int OpPad<Rank, Dtype>::checkTensorAttributes()
146{
Jerry Gea793f462023-04-11 00:05:02 +0000147 // Check Tosa Level
148 auto tosa_level = g_func_config.tosa_level;
149 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
150
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 if (validateRequiredOperands())
152 return 1;
153
154 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
155 {
156 return 1;
157 }
158
159 // output and input must be the same types
160 if (inputs[0]->matchRankType(*outputs[0]))
161 {
162 printNodeValidationError("Failure to match input and output type and rank");
163 return 1;
164 }
165
Kevin Chengfe392ce2021-10-18 21:51:55 +0000166 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
167 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
168 ASSERT_MEM(in && out);
169
170 // padding in spec is 2D array in shape of [Rank, 2]
171 // Reference model implement this as 1D array of [Rank * 2], with ordering:
172 // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
173 ERROR_IF(attribute->padding().size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");
174
175 for (int i = 0; i < Rank; i++)
176 {
177 int32_t pad_front = attribute->padding()[2 * i];
178 int32_t pad_back = attribute->padding()[2 * i + 1];
179 ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
Eric Kunze3c59d5d2022-08-15 11:30:33 -0700180 ERROR_IF(out->getShape()[i] != pad_front + in->getShape()[i] + pad_back,
181 "OpPad: output shape not equal to input plus padding");
Kevin Chengfe392ce2021-10-18 21:51:55 +0000182 paddings_array[i] = std::make_pair(pad_front, pad_back);
183 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700184
Eric Kunzee5e26762020-10-13 16:11:07 -0700185 return 0;
186}
187
188template <int Rank, DType Dtype>
189int OpPad<Rank, Dtype>::eval()
190{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000191 InEigenType pad_value = 0;
192
193 switch (Dtype)
Kevin Chengcc61be32021-10-14 17:09:57 -0700194 {
Kevin Chengfe392ce2021-10-18 21:51:55 +0000195 case DType_BOOL:
196 case DType_INT8:
197 case DType_INT16:
198 case DType_INT32:
199 pad_value = (InEigenType)attribute->pad_const_int();
200 break;
James Ward8b390432022-08-12 20:48:56 +0100201 case DType_FP16:
James Ward34071252022-12-07 15:48:47 +0000202 case DType_BF16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100203 case DType_FP32:
Kevin Chengfe392ce2021-10-18 21:51:55 +0000204 pad_value = (InEigenType)attribute->pad_const_fp();
205 break;
TatWai Chong86c403b2022-06-06 20:46:01 -0700206 default:
207 printNodeValidationError("Unsupported data type");
208 break;
Kevin Chengcc61be32021-10-14 17:09:57 -0700209 }
210
Eric Kunzee5e26762020-10-13 16:11:07 -0700211 this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
212
213 return GraphNode::eval();
214}
215
216template <int InRank, int OutRank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700217OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
218 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700219 uint64_t id_)
220 : GraphNode(sgt_, Op_RESHAPE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700221{
222 setRequiredOperands(1, 1);
223 setRequiredRank(0, 6);
224
225 INIT_ATTRIBUTE(Reshape);
226}
227
228template <int InRank, int OutRank, DType Dtype>
229OpReshape<InRank, OutRank, Dtype>::~OpReshape()
230{
231 if (attribute)
232 delete attribute;
233}
234
235template <int InRank, int OutRank, DType Dtype>
236int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
237{
Jerry Gea793f462023-04-11 00:05:02 +0000238 // Check Tosa Level
239 auto tosa_level = g_func_config.tosa_level;
240 LEVEL_CHECK(InRank <= tosa_level.MAX_RANK, "InRank should be smaller than or equal to MAX_RANK");
241 LEVEL_CHECK(OutRank <= tosa_level.MAX_RANK, "OutRank should be smaller than or equal to MAX_RANK");
242
Eric Kunzee5e26762020-10-13 16:11:07 -0700243 if (validateRequiredOperands())
244 return 1;
245
246 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
247 {
248 return 1;
249 }
250
251 // output and input must be the same types
252 if (inputs[0]->matchType(*outputs[0]))
253 {
254 printNodeValidationError("OpReshape: Input and output types must match");
255 return 1;
256 }
257
Kevin Chengcc61be32021-10-14 17:09:57 -0700258 ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
259 "Input tensor size does not match output tensor size");
260
Eric Kunzee5e26762020-10-13 16:11:07 -0700261 for (uint32_t d = 0; d < OutRank; d++)
262 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700263 ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
Jeremy Johnsonc23fc3b2022-05-30 16:51:21 +0100264 "OpReshape: new_shape doesn't match output shape");
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 }
266
267 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
268 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
269
270 return 0;
271}
272
273template <int InRank, int OutRank, DType Dtype>
274int OpReshape<InRank, OutRank, Dtype>::eval()
275{
Eric Kunzee5e26762020-10-13 16:11:07 -0700276 for (int32_t d = 0; d < OutRank; d++)
277 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700278 array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
Eric Kunzee5e26762020-10-13 16:11:07 -0700279 out_reverser[d] = OutRank - 1 - d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700280 }
281
282 for (int32_t d = 0; d < InRank; d++)
283 {
284 in_reverser[d] = InRank - 1 - d;
285 }
286
287 // Eigen Tensor is col-major, and we're referencing row-major result
288 // need to reverse it to row-major before reshape, and perform another reverse afterward
289
290 // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
291 TIn in_reversed;
292 if (InRank > 1)
293 {
294 in_reversed = in->getTensor().shuffle(in_reverser);
295 }
296 else
297 {
298 in_reversed = in->getTensor();
299 }
300
301 TOut in_reshaped = in_reversed.reshape(array_shape);
302
303 // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
304 if (OutRank > 1)
305 {
306 out->getTensor() = in_reshaped.shuffle(out_reverser);
307 }
308 else
309 {
310 out->getTensor() = in_reshaped;
311 }
312
313 return GraphNode::eval();
314}
315
316template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700317OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
318 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700319 uint64_t id_)
320 : GraphNode(sgt_, Op_REVERSE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700321{
322 setRequiredOperands(1, 1);
323 setRequiredRank(1, 6);
324
325 INIT_ATTRIBUTE(Axis);
326}
327
328template <int Rank, DType Dtype>
329OpReverse<Rank, Dtype>::~OpReverse()
330{
331 if (attribute)
332 delete attribute;
333}
334
335template <int Rank, DType Dtype>
336int OpReverse<Rank, Dtype>::checkTensorAttributes()
337{
Jerry Gea793f462023-04-11 00:05:02 +0000338 // Check Tosa Level
339 auto tosa_level = g_func_config.tosa_level;
340 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
341
Eric Kunzee5e26762020-10-13 16:11:07 -0700342 if (validateRequiredOperands())
343 return 1;
344
345 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
346 {
347 return 1;
348 }
349
350 // output and input must be the same types
351 if (inputs[0]->matchRankTypeShape(*outputs[0]))
352 {
353 printNodeValidationError("Failure to match input and output rank/type/shape");
354 return 1;
355 }
356
357 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
358 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
359
360 ASSERT_MEM(in && out);
361
362 if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
363 {
364 printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
365 return 1;
366 }
367
368 // transform list of axis into true or false list
369 // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
370 for (int i = 0; i < Rank; i++)
371 {
372 reverse_array[i] = false;
373 }
374 reverse_array[attribute->axis()] = true;
375
376 return 0;
377}
378
379template <int Rank, DType Dtype>
380int OpReverse<Rank, Dtype>::eval()
381{
382 out->getTensor() = in->getTensor().reverse(reverse_array);
383
384 return GraphNode::eval();
385}
386
387template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700388OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
389 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700390 uint64_t id_)
391 : GraphNode(sgt_, Op_SLICE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700392{
393 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000394 setRequiredRank(1, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -0700395
396 INIT_ATTRIBUTE(Slice);
397}
398
399template <int Rank, DType Dtype>
400OpSlice<Rank, Dtype>::~OpSlice()
401{
402 if (attribute)
403 delete attribute;
404}
405
406template <int Rank, DType Dtype>
407int OpSlice<Rank, Dtype>::checkTensorAttributes()
408{
Jerry Gea793f462023-04-11 00:05:02 +0000409 // Check Tosa Level
410 auto tosa_level = g_func_config.tosa_level;
411 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
412
Eric Kunzee5e26762020-10-13 16:11:07 -0700413 if (validateRequiredOperands())
414 return 1;
415
416 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
417 {
418 return 1;
419 }
420
421 // output and input must be the same types
Luke Huttona4e48ca2023-02-22 11:53:48 +0000422 if (inputs[0]->matchRankType(*outputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700423 {
Luke Huttona4e48ca2023-02-22 11:53:48 +0000424 printNodeValidationError("Failure to match input and output rank or type");
Eric Kunzee5e26762020-10-13 16:11:07 -0700425 return 1;
426 }
427
428 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
429 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
430
Luke Huttona4e48ca2023-02-22 11:53:48 +0000431 ASSERT_MEM(in && out);
432
TatWai Chong86c403b2022-06-06 20:46:01 -0700433 ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
Kevin Chengcc61be32021-10-14 17:09:57 -0700434 "OpSlice: begin array length needs to be rank(input)");
435 ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
Eric Kunzee5e26762020-10-13 16:11:07 -0700436
Kevin Chengcc61be32021-10-14 17:09:57 -0700437 for (int32_t i = 0; i < in->getRank(); i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700438 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700439 int32_t b = attribute->start()[i];
Kevin Chengcc61be32021-10-14 17:09:57 -0700440 int32_t s = attribute->size()[i];
441 ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
442 ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
443 ERROR_IF(s <= 0, "OpSlice: output must be positive");
444 ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
445 begin_array[i] = b;
446 size_array[i] = s;
Eric Kunzee5e26762020-10-13 16:11:07 -0700447 }
448
449 return 0;
450}
451
452template <int Rank, DType Dtype>
453int OpSlice<Rank, Dtype>::eval()
454{
455 out->getTensor() = in->getTensor().slice(begin_array, size_array);
456
457 return GraphNode::eval();
458}
459
460template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700461OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
462 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700463 uint64_t id_)
464 : GraphNode(sgt_, Op_TILE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700465{
466 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000467 setRequiredRank(1, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -0700468
469 INIT_ATTRIBUTE(Tile);
470}
471
472template <int Rank, DType Dtype>
473OpTileBase<Rank, Dtype>::~OpTileBase()
474{
475 if (attribute)
476 delete attribute;
477}
478
479template <int Rank, DType Dtype>
480int OpTileBase<Rank, Dtype>::checkTensorAttributes()
481{
Jerry Gea793f462023-04-11 00:05:02 +0000482 // Check Tosa Level
483 auto tosa_level = g_func_config.tosa_level;
484 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
485
Eric Kunzee5e26762020-10-13 16:11:07 -0700486 if (validateRequiredOperands())
487 return 1;
488
489 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
490 {
491 return 1;
492 }
493
494 // output and input must be the same ranks and types
495 if (inputs[0]->matchRankType(*outputs[0]))
496 {
497 printNodeValidationError("Failure to match input and output rank or type");
498 return 1;
499 }
500
501 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
502 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
503
Luke Huttona4e48ca2023-02-22 11:53:48 +0000504 ASSERT_MEM(in && out);
505
Eric Kunzee5e26762020-10-13 16:11:07 -0700506 if (attribute->multiples().size() != Rank)
507 {
508 printNodeValidationError("1D list 'multiples' must have size equal to input rank");
509 return 1;
510 }
511
512 for (int32_t d = 0; d < Rank; d++)
513 {
Eric Kunzeec0327d2022-05-24 15:22:06 -0700514 ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d],
515 "Output shape not equal to input * multiples;")
Eric Kunzee5e26762020-10-13 16:11:07 -0700516 }
517
518 return 0;
519}
520
521template <int Rank, DType Dtype>
522int OpTile<Rank, Dtype>::eval()
523{
524 // primary template shouldn't be called
Kevin Chengacb550f2021-06-29 15:32:19 -0700525 FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700526}
527
528template <DType Dtype>
529int OpTile<1, Dtype>::eval()
530{
531 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
532 {
533 int32_t id0 = od0 % this->in->getShape()[0];
534 this->out->getTensor()(od0) = this->in->getTensor()(id0);
535 }
536
537 return GraphNode::eval();
538}
539
540template <DType Dtype>
541int OpTile<2, Dtype>::eval()
542{
543 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
544 {
545 int32_t id0 = od0 % this->in->getShape()[0];
546 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
547 {
548 int32_t id1 = od1 % this->in->getShape()[1];
549 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
550 }
551 }
552
553 return GraphNode::eval();
554}
555
556template <DType Dtype>
557int OpTile<3, Dtype>::eval()
558{
559 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
560 {
561 int32_t id0 = od0 % this->in->getShape()[0];
562 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
563 {
564 int32_t id1 = od1 % this->in->getShape()[1];
565 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
566 {
567 int32_t id2 = od2 % this->in->getShape()[2];
568 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
569 }
570 }
571 }
572
573 return GraphNode::eval();
574}
575
576template <DType Dtype>
577int OpTile<4, Dtype>::eval()
578{
579 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
580 {
581 int32_t id0 = od0 % this->in->getShape()[0];
582 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
583 {
584 int32_t id1 = od1 % this->in->getShape()[1];
585 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
586 {
587 int32_t id2 = od2 % this->in->getShape()[2];
588 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
589 {
590 int32_t id3 = od3 % this->in->getShape()[3];
591 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
592 }
593 }
594 }
595 }
596
597 return GraphNode::eval();
598}
599
Luke Huttona4e48ca2023-02-22 11:53:48 +0000600template <DType Dtype>
601int OpTile<5, Dtype>::eval()
602{
603 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
604 {
605 int32_t id0 = od0 % this->in->getShape()[0];
606 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
607 {
608 int32_t id1 = od1 % this->in->getShape()[1];
609 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
610 {
611 int32_t id2 = od2 % this->in->getShape()[2];
612 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
613 {
614 int32_t id3 = od3 % this->in->getShape()[3];
615 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
616 {
617 int32_t id4 = od4 % this->in->getShape()[4];
618 this->out->getTensor()(od0, od1, od2, od3, od4) =
619 this->in->getTensor()(id0, id1, id2, id3, id4);
620 }
621 }
622 }
623 }
624 }
625
626 return GraphNode::eval();
627}
628
629template <DType Dtype>
630int OpTile<6, Dtype>::eval()
631{
632 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
633 {
634 int32_t id0 = od0 % this->in->getShape()[0];
635 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
636 {
637 int32_t id1 = od1 % this->in->getShape()[1];
638 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
639 {
640 int32_t id2 = od2 % this->in->getShape()[2];
641 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
642 {
643 int32_t id3 = od3 % this->in->getShape()[3];
644 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
645 {
646 int32_t id4 = od4 % this->in->getShape()[4];
647 for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
648 {
649 int32_t id5 = od5 % this->in->getShape()[5];
650 this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
651 this->in->getTensor()(id0, id1, id2, id3, id4, id5);
652 }
653 }
654 }
655 }
656 }
657 }
658
659 return GraphNode::eval();
660}
661
Eric Kunzee5e26762020-10-13 16:11:07 -0700662template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700663OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
664 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700665 uint64_t id_)
666 : GraphNode(sgt_, Op_TRANSPOSE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700667{
Kevin Chengfe392ce2021-10-18 21:51:55 +0000668 setRequiredOperands(1, 1);
Luke Huttona4e48ca2023-02-22 11:53:48 +0000669 setRequiredRank(1, 6);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000670
671 INIT_ATTRIBUTE(Transpose);
Eric Kunzee5e26762020-10-13 16:11:07 -0700672}
673
674template <int Rank, DType Dtype>
675OpTranspose<Rank, Dtype>::~OpTranspose()
Jerry Gea6827492022-11-16 10:41:55 -0800676{
677 if (attribute) delete attribute;
678}
Eric Kunzee5e26762020-10-13 16:11:07 -0700679
680template <int Rank, DType Dtype>
681int OpTranspose<Rank, Dtype>::checkTensorAttributes()
682{
Jerry Gea793f462023-04-11 00:05:02 +0000683 // Check Tosa Level
684 auto tosa_level = g_func_config.tosa_level;
685 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
686
Eric Kunzee5e26762020-10-13 16:11:07 -0700687 if (validateRequiredOperands())
688 return 1;
689
690 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
691 {
692 return 1;
693 }
694
695 // output and input must be the same types
696 if (inputs[0]->matchRankType(*outputs[0]))
697 {
698 printNodeValidationError("Failure to match input and output rank and type");
699 return 1;
700 }
701
702 if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
703 {
704 printNodeValidationError("Failure to match input and output total element count");
705 return 1;
706 }
707
Kevin Chengfe392ce2021-10-18 21:51:55 +0000708 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
709 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
710
711 ASSERT_MEM(in && out);
Eric Kunzee5e26762020-10-13 16:11:07 -0700712
TatWai Chong86c403b2022-06-06 20:46:01 -0700713 ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)");
Kevin Chengf3e016f2021-11-02 01:15:50 +0000714
715 std::array<bool, Rank> index_used;
716 index_used.fill(false);
717 for (int32_t d = 0; d < Rank; d++)
718 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700719 int32_t index = attribute->perms()[d];
Kevin Chengf3e016f2021-11-02 01:15:50 +0000720 ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
721 ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
722 index_used[index] = true;
723 ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
724 perm_array[d] = index;
725 }
726
Eric Kunzee5e26762020-10-13 16:11:07 -0700727 return 0;
728}
729
730template <int Rank, DType Dtype>
731int OpTranspose<Rank, Dtype>::eval()
732{
Eric Kunzee5e26762020-10-13 16:11:07 -0700733 out->getTensor() = in->getTensor().shuffle(perm_array);
734
735 return GraphNode::eval();
736}
737
738// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +0100739DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
James Ward24dbc422022-10-19 12:20:31 +0100740DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100741DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700742DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
743DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
744DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
745DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
746
James Ward8b390432022-08-12 20:48:56 +0100747DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100748DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100749DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700750DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
751DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
752DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
753DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
754
James Ward8b390432022-08-12 20:48:56 +0100755DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100756DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100757DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700758DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
759DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
760DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
761DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
762
James Ward8b390432022-08-12 20:48:56 +0100763DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100764DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100765DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700766DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
767DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
768DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
769DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
770
Luke Huttona4e48ca2023-02-22 11:53:48 +0000771DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
772DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
773DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
774DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
775DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
776DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
777DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700778
Luke Huttona4e48ca2023-02-22 11:53:48 +0000779DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
780DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
781DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
782DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
783DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
784DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
785DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
Jared Smolens98c281f2022-12-20 15:09:25 -0800786
Luke Huttona4e48ca2023-02-22 11:53:48 +0000787DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
788DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
789DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
790DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
791DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
792DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
793DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
Eric Kunzee5e26762020-10-13 16:11:07 -0700794
Luke Huttona4e48ca2023-02-22 11:53:48 +0000795DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
796DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
797DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
798DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
799DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
800DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
801DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);