blob: e264284adb03bce577883102fecf45e2db711425 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Ly8690a082023-12-18 20:40:24 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "data_layout.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
// Constructor: register operand/rank requirements and capture the Axis attribute.
template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONCAT, id_)
{
    // -1 inputs: CONCAT is variadic (any number of input tensors), exactly 1 output.
    setRequiredOperands(-1, 1);
    setRequiredRank(1);

    INIT_ATTRIBUTE(Axis);
}
32
// Destructor: release the attribute copy created by INIT_ATTRIBUTE.
template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::~OpConcat()
{
    if (attribute)
        delete attribute;
}
39
// Validate CONCAT operands: all inputs must match the output's rank and type,
// the axis must be in range, and the output dimension on the concat axis must
// equal the sum of the input dimensions on that axis.
// Returns 0 on success, 1 on validation failure.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (inputs.empty())
    {
        printNodeValidationError("Concat operator must have at least one input tensor");
        return 1;
    }

    int32_t num_inputs = inputs.size();

    // output and input must be the same types and rank
    for (int32_t i = 0; i < num_inputs; i++)
    {
        if (inputs[i]->matchRankType(*outputs[0]))
        {
            printNodeValidationError("OpConcat: input ranks and types must match");
            return 1;
        }
        // Cache the typed tensor pointer for use in eval().
        ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
    }

    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
    {
        printNodeValidationError("OpConcat: axis is beyond output tensor rank");
        return 1;
    }

    // Accumulate the concat-axis extents; all other axes must match the output exactly.
    int32_t output_dim_on_axis = 0;
    for (int32_t j = 0; j < num_inputs; j++)
    {
        for (int32_t i = 0; i < Rank; i++)
        {
            int32_t input_dim = inputs[j]->getShape()[i];
            if (i == attribute->axis())
            {
                output_dim_on_axis += input_dim;
            }
            else if (input_dim != outputs[0]->getShape()[i])
            {
                printNodeValidationError("OpConcat: input dimension not matching output dimension");
                return 1;
            }
        }
    }

    ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
             "OpConcat: sum of input dimension on axis not equal to output dimension on axis");

    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    return 0;
}
100
// Concatenate all cached inputs along the attribute axis.
// Eigen tensors here are col-major while the reference model reasons in row-major,
// so each tensor is dimension-reversed (shuffle), concatenated on the mirrored
// axis, then reversed back before being written to the output.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::eval()
{

    // The concat axis seen by the reversed (col-major) layout.
    int32_t reversed_axis = Rank - 1 - attribute->axis();

    for (int32_t d = 0; d < Rank; d++)
    {
        reverser[d] = Rank - 1 - d;
    }

    TIn result = ins[0]->getTensor().shuffle(reverser);

    // Fold remaining inputs into the running result one at a time.
    for (size_t i = 1; i < ins.size(); i++)
    {
        TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
        TIn temp        = result.concatenate(in_reversed, reversed_axis);
        result          = temp;
    }
    // Undo the dimension reversal to restore row-major ordering.
    out->getTensor() = result.shuffle(reverser);

    return GraphNode::eval();
}
124
// Constructor: PAD takes two operands — the data tensor (inputs[0]) and the
// 1-D padding amounts tensor (inputs[1]) — and produces one output.
template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_PAD, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(1);

    INIT_ATTRIBUTE(Pad);
}
134
// Destructor: intentionally empty.
// NOTE(review): the constructor runs INIT_ATTRIBUTE(Pad), but unlike
// OpConcat/OpTranspose no `delete attribute` happens here — confirm the
// macro's ownership semantics; this looks like a possible leak.
template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::~OpPad()
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700138
Tai Lya4d748b2023-03-28 22:06:56 +0000139template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700140int OpPad<Rank, Dtype>::checkTensorAttributes()
141{
Jerry Gea793f462023-04-11 00:05:02 +0000142 // Check Tosa Level
143 auto tosa_level = g_func_config.tosa_level;
144 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
145
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 if (validateRequiredOperands())
147 return 1;
148
149 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
150 {
151 return 1;
152 }
153
154 // output and input must be the same types
155 if (inputs[0]->matchRankType(*outputs[0]))
156 {
157 printNodeValidationError("Failure to match input and output type and rank");
158 return 1;
159 }
160
Tai Lye095da72024-01-25 22:00:18 +0000161 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
162 padding = dynamic_cast<TosaReference::TensorTemplate<TPadding>*>(inputs[1]);
163 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000164 ASSERT_MEM(in && out);
165
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 return 0;
167}
168
// Apply padding: decode the scalar pad constant from the serialized attribute
// bytes, validate the per-axis pad amounts against the output shape, then pad
// the input tensor with Eigen's pad() expression.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::eval()
{
    InEigenType pad_value = 0;

    // pad_const() is stored as raw bytes; decode exactly one scalar using the
    // converter matching this Dtype's class (integer vs floating-point).
    switch (Dtype)
    {
        case TOSA_REF_TYPE_BOOL:
        case TOSA_REF_TYPE_INT8:
        case TOSA_REF_TYPE_INT16:
        case TOSA_REF_TYPE_INT32: {
            std::vector<int32_t> int32_data;
            TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
                                                     /* size = */ 1, int32_data);
            pad_value = (InEigenType)int32_data[0];
            break;
        }
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
        case TOSA_REF_TYPE_FP64:
        case TOSA_REF_TYPE_FP8E4M3:
        case TOSA_REF_TYPE_FP8E5M2: {
            std::vector<float> float_data;
            TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
                                                     /* size = */ 1, float_data);
            pad_value = (InEigenType)float_data[0];
            break;
        }
        default:
            ASSERT_MSG(false, "TOSA_REF_TYPE %s is not supported.", EnumNameTOSAREFTYPE(Dtype));
            break;
    }

    // padding is an 1D array of [Rank * 2], with ordering:
    // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
    TPadding padding_val = this->padding->getTensor();
    ERROR_IF(padding_val.size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");
    for (int i = 0; i < Rank; i++)
    {
        auto pad_front = padding_val(2 * i);
        auto pad_back  = padding_val(2 * i + 1);
        ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
        ERROR_IF(out->getShape()[i] != pad_front + in->getShape()[i] + pad_back,
                 "OpPad: output shape not equal to input plus padding");
        paddings_array[i] = std::make_pair(pad_front, pad_back);
    }

    this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);

    return GraphNode::eval();
}
221
// Constructor: DIM takes one input tensor and emits one output (the size of
// the selected axis); the axis comes from the Axis attribute.
template <int Rank, TOSA_REF_TYPE Dtype>
OpDim<Rank, Dtype>::OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_DIM, id_)
{
    setRequiredOperands(1, 1);

    INIT_ATTRIBUTE(Axis);
}
230
// Destructor: release the attribute copy created by INIT_ATTRIBUTE.
template <int Rank, TOSA_REF_TYPE Dtype>
OpDim<Rank, Dtype>::~OpDim()
{
    if (attribute)
        delete attribute;
}
237
// Validate DIM operands: axis must be within [0, Rank) and the typed input and
// output tensor pointers must resolve. Returns 0 on success, 1 on failure.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpDim<Rank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]))
        return 1;

    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
    {
        printNodeValidationError("OpDim: axis must between [0, input_rank - 1]");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    return 0;
}
264
265template <int Rank, TOSA_REF_TYPE Dtype>
266int OpDim<Rank, Dtype>::eval()
267{
268 int32_t axis = attribute->axis();
269 int64_t out_val = in->getShape()[axis];
270
Tai Ly8690a082023-12-18 20:40:24 +0000271 this->out->getTensor().setValues({ out_val });
Won Jeona21b2e82023-08-10 10:33:01 +0000272
Jerry Ge12159fc2024-04-01 17:05:10 +0000273 // set the shapeValue given the actual tensor value
274 std::vector<int> shapeValue;
275 for (int i = 0; i < out->getTensor().size(); ++i)
276 {
277 shapeValue.push_back(out->getTensor()(i));
278 }
279
280 this->getOutputs()[0]->setShapeValue(shapeValue);
281
Won Jeona21b2e82023-08-10 10:33:01 +0000282 return GraphNode::eval();
283}
284
// Constructor: RESHAPE takes two operands — the data tensor and the target
// shape tensor (the latter is validated but not read in eval()).
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RESHAPE, id_)
{
    setRequiredOperands(2, 1);
}
291
// Destructor: no attribute is allocated by this node, nothing to release.
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
// Validate RESHAPE operands: matching element types, identical element counts,
// and resolvable typed tensor pointers. Returns 0 on success, 1 on failure.
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(InRank <= tosa_level.MAX_RANK, "InRank should be smaller than or equal to MAX_RANK");
    LEVEL_CHECK(OutRank <= tosa_level.MAX_RANK, "OutRank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpReshape: Input and output types must match");
        return 1;
    }

    // Reshape never changes the number of elements.
    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
             "Input tensor size does not match output tensor size");

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // note: do not assert mem on shape input, because it may be {} for reshape to scalar
    // and also, because the shape input is not actually used in eval()

    ASSERT_MEM(in && out)

    return 0;
}
327
// Perform the reshape. Eigen tensors are col-major while the reference model
// is row-major, so the input is dimension-reversed, reshaped against the
// reversed output shape, then reversed back. Rank-0/1 tensors cannot be
// shuffled, so those cases bypass the reversal.
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
    // Build the reversed output shape and the permutations for both ranks.
    for (int32_t d = 0; d < OutRank; d++)
    {
        array_shape[d]  = getOutputs()[0]->getShape()[OutRank - 1 - d];
        out_reverser[d] = OutRank - 1 - d;
    }

    for (int32_t d = 0; d < InRank; d++)
    {
        in_reverser[d] = InRank - 1 - d;
    }

    // Eigen Tensor is col-major, and we're referencing row-major result
    // need to reverse it to row-major before reshape, and perform another reverse afterward

    // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
    TIn in_reversed;
    if (InRank > 1)
    {
        in_reversed = in->getTensor().shuffle(in_reverser);
    }
    else
    {
        in_reversed = in->getTensor();
    }

    TOut in_reshaped = in_reversed.reshape(array_shape);

    // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
    if (OutRank > 1)
    {
        out->getTensor() = in_reshaped.shuffle(out_reverser);
    }
    else
    {
        out->getTensor() = in_reshaped;
    }

    return GraphNode::eval();
}
370
// Constructor: REVERSE takes one input and one output; the axis to flip comes
// from the Axis attribute.
template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_REVERSE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1);

    INIT_ATTRIBUTE(Axis);
}
380
// Destructor: release the attribute copy created by INIT_ATTRIBUTE.
template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::~OpReverse()
{
    if (attribute)
        delete attribute;
}
387
// Validate REVERSE operands: input/output must match rank, type and shape,
// axis must be in range; also precompute the boolean per-axis flip mask used
// by eval(). Returns 0 on success, 1 on failure.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank/type/shape");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
    {
        printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
        return 1;
    }

    // transform list of axis into true or false list
    // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
    for (int i = 0; i < Rank; i++)
    {
        reverse_array[i] = false;
    }
    reverse_array[attribute->axis()] = true;

    return 0;
}
431
Tai Lya4d748b2023-03-28 22:06:56 +0000432template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700433int OpReverse<Rank, Dtype>::eval()
434{
435 out->getTensor() = in->getTensor().reverse(reverse_array);
436
437 return GraphNode::eval();
438}
439
// Constructor: SLICE takes three operands — the data tensor plus the 1-D
// `start` and `size` tensors — and one output.
template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_SLICE, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(1);
}
447
// Destructor: no attribute is allocated by this node, nothing to release.
template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700451
// Validate SLICE operands: input/output must match rank and type, and the
// four typed tensor pointers (data, start, size, output) must all resolve.
// Returns 0 on success, 1 on failure.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank or type");
        return 1;
    }

    in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    start = dynamic_cast<TosaReference::TensorTemplate<TSlicing>*>(inputs[1]);
    size  = dynamic_cast<TosaReference::TensorTemplate<TSlicing>*>(inputs[2]);
    out   = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out && start && size);

    return 0;
}
483
// Extract the slice described by the `start` and `size` operand tensors,
// validating every per-axis value against the input and output shapes before
// delegating to Eigen's slice() expression.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::eval()
{
    TSlicing start_tensor = start->getTensor();
    TSlicing size_tensor  = size->getTensor();

    // According to https://eigen.tuxfamily.org/dox/unsupported/eigen_tensors.html
    // The type of size() is <Tensor-Type>::Index, but can always handily use it like an int.
    // However, apply explicit cast to int32_t is preferred.
    ERROR_IF(static_cast<int32_t>(start_tensor.size()) != in->getRank(),
             "OpSlice: start array length needs to be rank(input)");
    ERROR_IF(static_cast<int32_t>(size_tensor.size()) != in->getRank(),
             "OpSlice: size array length needs to be rank(input)");

    for (int32_t i = 0; i < in->getRank(); i++)
    {
        int32_t b = start_tensor(i);
        int32_t s = size_tensor(i);
        // The slice must lie entirely inside the input and match the output shape.
        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
        ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
        ERROR_IF(s <= 0, "OpSlice: output must be positive");
        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
        begin_array[i] = b;
        size_array[i]  = s;
    }

    out->getTensor() = in->getTensor().slice(begin_array, size_array);

    return GraphNode::eval();
}
514
// Constructor: TILE takes two operands — the data tensor and the 1-D
// `multiples` tensor — and one output.
template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_TILE, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(1);
}
522
// Destructor: no attribute is allocated by this node, nothing to release.
template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700526
// Validate TILE operands: input/output must match rank and type, the typed
// tensor pointers must resolve, and `multiples` must hold exactly Rank values.
// Returns 0 on success, 1 on failure.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same ranks and types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank or type");
        return 1;
    }

    in        = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    multiples = dynamic_cast<TosaReference::TensorTemplate<TInMultiples>*>(inputs[1]);
    out       = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && multiples && out);

    // One tile multiplier is required per input dimension.
    if (multiples->getElementCount() != Rank)
    {
        printNodeValidationError("1D list 'multiples' must have size equal to input rank");
        return 1;
    }

    return 0;
}
563
// Primary template: never invoked — only the rank 1..6 specializations below
// implement TILE. Reaching this is a fatal internal error.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTile<Rank, Dtype>::eval()
{
    // primary template shouldn't be called
    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype));
}
570
Tai Lya4d748b2023-03-28 22:06:56 +0000571template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700572int OpTile<1, Dtype>::eval()
573{
574 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
575 {
576 int32_t id0 = od0 % this->in->getShape()[0];
577 this->out->getTensor()(od0) = this->in->getTensor()(id0);
578 }
579
580 return GraphNode::eval();
581}
582
Tai Lya4d748b2023-03-28 22:06:56 +0000583template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700584int OpTile<2, Dtype>::eval()
585{
586 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
587 {
588 int32_t id0 = od0 % this->in->getShape()[0];
589 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
590 {
591 int32_t id1 = od1 % this->in->getShape()[1];
592 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
593 }
594 }
595
596 return GraphNode::eval();
597}
598
Tai Lya4d748b2023-03-28 22:06:56 +0000599template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700600int OpTile<3, Dtype>::eval()
601{
602 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
603 {
604 int32_t id0 = od0 % this->in->getShape()[0];
605 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
606 {
607 int32_t id1 = od1 % this->in->getShape()[1];
608 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
609 {
610 int32_t id2 = od2 % this->in->getShape()[2];
611 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
612 }
613 }
614 }
615
616 return GraphNode::eval();
617}
618
Tai Lya4d748b2023-03-28 22:06:56 +0000619template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700620int OpTile<4, Dtype>::eval()
621{
622 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
623 {
624 int32_t id0 = od0 % this->in->getShape()[0];
625 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
626 {
627 int32_t id1 = od1 % this->in->getShape()[1];
628 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
629 {
630 int32_t id2 = od2 % this->in->getShape()[2];
631 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
632 {
633 int32_t id3 = od3 % this->in->getShape()[3];
634 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
635 }
636 }
637 }
638 }
639
640 return GraphNode::eval();
641}
642
Tai Lya4d748b2023-03-28 22:06:56 +0000643template <TOSA_REF_TYPE Dtype>
Luke Huttona4e48ca2023-02-22 11:53:48 +0000644int OpTile<5, Dtype>::eval()
645{
646 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
647 {
648 int32_t id0 = od0 % this->in->getShape()[0];
649 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
650 {
651 int32_t id1 = od1 % this->in->getShape()[1];
652 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
653 {
654 int32_t id2 = od2 % this->in->getShape()[2];
655 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
656 {
657 int32_t id3 = od3 % this->in->getShape()[3];
658 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
659 {
660 int32_t id4 = od4 % this->in->getShape()[4];
661 this->out->getTensor()(od0, od1, od2, od3, od4) =
662 this->in->getTensor()(id0, id1, id2, id3, id4);
663 }
664 }
665 }
666 }
667 }
668
669 return GraphNode::eval();
670}
671
Tai Lya4d748b2023-03-28 22:06:56 +0000672template <TOSA_REF_TYPE Dtype>
Luke Huttona4e48ca2023-02-22 11:53:48 +0000673int OpTile<6, Dtype>::eval()
674{
675 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
676 {
677 int32_t id0 = od0 % this->in->getShape()[0];
678 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
679 {
680 int32_t id1 = od1 % this->in->getShape()[1];
681 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
682 {
683 int32_t id2 = od2 % this->in->getShape()[2];
684 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
685 {
686 int32_t id3 = od3 % this->in->getShape()[3];
687 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
688 {
689 int32_t id4 = od4 % this->in->getShape()[4];
690 for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
691 {
692 int32_t id5 = od5 % this->in->getShape()[5];
693 this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
694 this->in->getTensor()(id0, id1, id2, id3, id4, id5);
695 }
696 }
697 }
698 }
699 }
700 }
701
702 return GraphNode::eval();
703}
704
// Constructor: TRANSPOSE takes one input and one output; the permutation comes
// from the Transpose attribute (perms).
template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1);

    INIT_ATTRIBUTE(Transpose);
}
714
// Destructor: release the attribute copy created by INIT_ATTRIBUTE.
template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::~OpTranspose()
{
    if (attribute)
        delete attribute;
}
Eric Kunzee5e26762020-10-13 16:11:07 -0700721
// Validate TRANSPOSE operands: matching rank/type and element count, and a
// perms attribute that is a true permutation of [0, Rank) consistent with the
// output shape. Also fills perm_array for eval(). Returns 0 on success.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank and type");
        return 1;
    }

    if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
    {
        printNodeValidationError("Failure to match input and output total element count");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)");

    // Verify perms is a permutation: every index in range, none repeated, and
    // each input extent lands on the matching output extent.
    std::array<bool, Rank> index_used;
    index_used.fill(false);
    for (int32_t d = 0; d < Rank; d++)
    {
        int32_t index = attribute->perms()[d];
        ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
        ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
        index_used[index] = true;
        ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
        perm_array[d] = index;
    }

    return 0;
}
771
Tai Lya4d748b2023-03-28 22:06:56 +0000772template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700773int OpTranspose<Rank, Dtype>::eval()
774{
Eric Kunzee5e26762020-10-13 16:11:07 -0700775 out->getTensor() = in->getTensor().shuffle(perm_array);
776
777 return GraphNode::eval();
778}
779
780// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +0100781DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
James Ward24dbc422022-10-19 12:20:31 +0100782DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100783DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700784DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
785DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
786DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
787DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
Tai Lya4d748b2023-03-28 22:06:56 +0000788DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
Won Jeon2c34b462024-02-06 18:37:00 +0000789DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
790DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700791
James Ward8b390432022-08-12 20:48:56 +0100792DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100793DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100794DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700795DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
796DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
797DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
798DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000799DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000800DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
801DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700802
Won Jeona21b2e82023-08-10 10:33:01 +0000803DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
804DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16);
805DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP32);
806DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
807DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
808DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32);
809DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
Won Jeon2c34b462024-02-06 18:37:00 +0000810DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
811DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
Won Jeona21b2e82023-08-10 10:33:01 +0000812
James Ward8b390432022-08-12 20:48:56 +0100813DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100814DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100815DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700816DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
817DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
818DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
819DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000820DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000821DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3);
822DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700823
James Ward8b390432022-08-12 20:48:56 +0100824DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100825DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100826DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700827DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
828DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
829DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
830DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000831DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000832DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
833DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700834
Luke Huttona4e48ca2023-02-22 11:53:48 +0000835DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
836DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
837DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
838DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
839DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
840DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
841DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000842DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000843DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
844DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700845
Luke Huttona4e48ca2023-02-22 11:53:48 +0000846DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
847DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
848DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
849DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
850DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
851DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
852DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000853DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000854DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3);
855DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2);
Jared Smolens98c281f2022-12-20 15:09:25 -0800856
Luke Huttona4e48ca2023-02-22 11:53:48 +0000857DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
858DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
859DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
860DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
861DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
862DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
863DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000864DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000865DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
866DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700867
Luke Huttona4e48ca2023-02-22 11:53:48 +0000868DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
869DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
870DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
871DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
872DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
873DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
874DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000875DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000876DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
877DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);