blob: 6664ec390785a879cbcca7ac16a61a07e2ffbc1b [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Ly8690a082023-12-18 20:40:24 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "data_layout.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONCAT, id_)
{
    // CONCAT accepts a variable number of inputs (-1) and produces one output.
    setRequiredOperands(-1, 1);
    setRequiredRank(1);

    // Copies the Axis attribute into the 'attribute' member (freed in the destructor).
    INIT_ATTRIBUTE(Axis);
}
32
Tai Lya4d748b2023-03-28 22:06:56 +000033template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070034OpConcat<Rank, Dtype>::~OpConcat()
35{
36 if (attribute)
37 delete attribute;
38}
39
Tai Lya4d748b2023-03-28 22:06:56 +000040template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070041int OpConcat<Rank, Dtype>::checkTensorAttributes()
42{
Jerry Gea793f462023-04-11 00:05:02 +000043 // Check Tosa Level
44 auto tosa_level = g_func_config.tosa_level;
45 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
46
Eric Kunzee5e26762020-10-13 16:11:07 -070047 if (validateRequiredOperands())
48 return 1;
49
Kevin Chengad15dfa2021-03-04 15:15:03 -080050 if (inputs.empty())
Eric Kunzee5e26762020-10-13 16:11:07 -070051 {
Kevin Chengad15dfa2021-03-04 15:15:03 -080052 printNodeValidationError("Concat operator must have at least one input tensor");
Eric Kunzee5e26762020-10-13 16:11:07 -070053 return 1;
54 }
Kevin Chengcc61be32021-10-14 17:09:57 -070055
56 int32_t num_inputs = inputs.size();
57
Eric Kunzee5e26762020-10-13 16:11:07 -070058 // output and input must be the same types and rank
Kevin Chengcc61be32021-10-14 17:09:57 -070059 for (int32_t i = 0; i < num_inputs; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -070060 {
Kevin Chengad15dfa2021-03-04 15:15:03 -080061 if (inputs[i]->matchRankType(*outputs[0]))
62 {
Kevin Chengcc61be32021-10-14 17:09:57 -070063 printNodeValidationError("OpConcat: input ranks and types must match");
Kevin Chengad15dfa2021-03-04 15:15:03 -080064 return 1;
65 }
66 ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
Eric Kunzee5e26762020-10-13 16:11:07 -070067 }
68
Kevin Chengcc61be32021-10-14 17:09:57 -070069 if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
Eric Kunzee5e26762020-10-13 16:11:07 -070070 {
Kevin Chengcc61be32021-10-14 17:09:57 -070071 printNodeValidationError("OpConcat: axis is beyond output tensor rank");
Eric Kunzee5e26762020-10-13 16:11:07 -070072 return 1;
73 }
74
Kevin Chengcc61be32021-10-14 17:09:57 -070075 int32_t output_dim_on_axis = 0;
76 for (int32_t j = 0; j < num_inputs; j++)
77 {
78 for (int32_t i = 0; i < Rank; i++)
79 {
80 int32_t input_dim = inputs[j]->getShape()[i];
81 if (i == attribute->axis())
82 {
83 output_dim_on_axis += input_dim;
84 }
85 else if (input_dim != outputs[0]->getShape()[i])
86 {
87 printNodeValidationError("OpConcat: input dimension not matching output dimension");
88 return 1;
89 }
90 }
91 }
92
Kevin Cheng6e528662021-10-20 17:35:33 +000093 ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
Kevin Chengcc61be32021-10-14 17:09:57 -070094 "OpConcat: sum of input dimension on axis not equal to output dimension on axis");
95
96 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
97
Eric Kunzee5e26762020-10-13 16:11:07 -070098 return 0;
99}
100
template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::eval()
{

    // Eigen tensors are col-major while the reference model indexes
    // row-major; map the concat axis into the reversed-dimension space.
    int32_t reversed_axis = Rank - 1 - attribute->axis();

    // Build the dimension-reversing permutation [Rank-1, ..., 1, 0].
    for (int32_t d = 0; d < Rank; d++)
    {
        reverser[d] = Rank - 1 - d;
    }

    // Reverse dims, concatenate along the reversed axis, then reverse back.
    TIn result = ins[0]->getTensor().shuffle(reverser);

    for (size_t i = 1; i < ins.size(); i++)
    {
        TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
        // Materialize into a temporary before reassigning 'result', since
        // the concatenate expression reads from 'result' itself.
        TIn temp = result.concatenate(in_reversed, reversed_axis);
        result   = temp;
    }
    out->getTensor() = result.shuffle(reverser);

    return GraphNode::eval();
}
124
template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_PAD, id_)
{
    // Two inputs (data tensor, padding tensor), one output.
    setRequiredOperands(2, 1);
    setRequiredRank(1);

    // Copies the Pad attribute (carries pad_const) into 'attribute'.
    INIT_ATTRIBUTE(Pad);
}
134
template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::~OpPad()
// NOTE(review): the constructor runs INIT_ATTRIBUTE(Pad), but nothing is
// released here — unlike OpConcat/OpDim/OpTranspose, whose destructors
// delete 'attribute'. Confirm ownership; this looks like a possible leak.
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700138
Tai Lya4d748b2023-03-28 22:06:56 +0000139template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700140int OpPad<Rank, Dtype>::checkTensorAttributes()
141{
Jerry Gea793f462023-04-11 00:05:02 +0000142 // Check Tosa Level
143 auto tosa_level = g_func_config.tosa_level;
144 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
145
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 if (validateRequiredOperands())
147 return 1;
148
149 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
150 {
151 return 1;
152 }
153
154 // output and input must be the same types
155 if (inputs[0]->matchRankType(*outputs[0]))
156 {
157 printNodeValidationError("Failure to match input and output type and rank");
158 return 1;
159 }
160
Tai Lye095da72024-01-25 22:00:18 +0000161 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
162 padding = dynamic_cast<TosaReference::TensorTemplate<TPadding>*>(inputs[1]);
163 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Chengfe392ce2021-10-18 21:51:55 +0000164 ASSERT_MEM(in && out);
165
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 return 0;
167}
168
template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::eval()
{
    // Pads the input with the constant 'pad_const' decoded from the
    // attribute; per-dimension pad amounts come from the 'padding' input.
    InEigenType pad_value = 0;

    // need to use input tensor's serializationDtype to deserialize pad_const
    // because Dtype may be FP64 in precise_mode
    switch (DType2RefType(inputs[0]->getSerializationDtype()))
    {
        // Each case decodes one pad_const element from the raw U8 buffer
        // and converts it to InEigenType.
        case TOSA_REF_TYPE_BOOL: {
            std::vector<bool> bool_data;
            TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(),
                                                      /* size = */ 1, bool_data);
            pad_value = (InEigenType)bool_data[0];
            break;
        }
        case TOSA_REF_TYPE_INT8: {
            std::vector<int8_t> int8_data;
            TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(),
                                                    /* size = */ 1, int8_data);
            pad_value = (InEigenType)int8_data[0];
            break;
        }
        case TOSA_REF_TYPE_INT16: {
            std::vector<int16_t> int16_data;
            TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(),
                                                     /* size = */ 1, int16_data);
            pad_value = (InEigenType)int16_data[0];
            break;
        }
        case TOSA_REF_TYPE_INT32: {
            std::vector<int32_t> int32_data;
            TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
                                                     /* size = */ 1, int32_data);
            pad_value = (InEigenType)int32_data[0];
            break;
        }
        case TOSA_REF_TYPE_FP16: {
            std::vector<half_float::half> f16_data;
            TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(),
                                                     /* size = */ 1, f16_data);
            pad_value = (InEigenType)f16_data[0];
            break;
        }
        case TOSA_REF_TYPE_BF16: {
            std::vector<float> f32_data;
            TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(),
                                                      /* size = */ 1, f32_data);
            pad_value = (InEigenType)f32_data[0];
            break;
        }
        case TOSA_REF_TYPE_FP32: {
            std::vector<float> f32_data;
            TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
                                                     /* size = */ 1, f32_data);
            pad_value = (InEigenType)f32_data[0];
            break;
        }
        case TOSA_REF_TYPE_FP8E4M3: {
            std::vector<float> f32_data;
            TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(),
                                                         /* size = */ 1, f32_data);
            pad_value = (InEigenType)f32_data[0];
            break;
        }
        case TOSA_REF_TYPE_FP8E5M2: {
            std::vector<float> float_data;
            TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(),
                                                         /* size = */ 1, float_data);
            pad_value = (InEigenType)float_data[0];
            break;
        }
        default:
            ASSERT_MSG(false, "TOSA_REF_TYPE %s is not supported.", EnumNameTOSAREFTYPE(Dtype));
            break;
    }

    // padding is an 1D array of [Rank * 2], with ordering:
    // [Rank0_front, Rank0_back, Rank1_front, Rank1_back, ..., Rank(N-1)_front, Rank(N-1)_back]
    TPadding padding_val = this->padding->getTensor();
    ERROR_IF(padding_val.size() != (Rank * 2), "OpPad: padding length needs to be (rank(input1) * 2)");
    for (int i = 0; i < Rank; i++)
    {
        auto pad_front = padding_val(2 * i);
        auto pad_back  = padding_val(2 * i + 1);
        // Negative padding is disallowed, and the output extent must be
        // exactly input extent plus both pads.
        ERROR_IF((pad_front < 0) || (pad_back < 0), "OpPad: padding can't be smaller than 0");
        ERROR_IF(out->getShape()[i] != pad_front + in->getShape()[i] + pad_back,
                 "OpPad: output shape not equal to input plus padding");
        paddings_array[i] = std::make_pair(pad_front, pad_back);
    }

    this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);

    return GraphNode::eval();
}
264
template <int Rank, TOSA_REF_TYPE Dtype>
OpDim<Rank, Dtype>::OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_DIM, id_)
{
    // One input, one output (the size of one input dimension).
    setRequiredOperands(1, 1);

    // Copies the Axis attribute into 'attribute' (freed in the destructor).
    INIT_ATTRIBUTE(Axis);
}
273
274template <int Rank, TOSA_REF_TYPE Dtype>
275OpDim<Rank, Dtype>::~OpDim()
276{
277 if (attribute)
278 delete attribute;
279}
280
281template <int Rank, TOSA_REF_TYPE Dtype>
282int OpDim<Rank, Dtype>::checkTensorAttributes()
283{
284 // Check Tosa Level
285 auto tosa_level = g_func_config.tosa_level;
286 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
287
288 if (validateRequiredOperands())
289 return 1;
290
291 if (validateRequiredRank(inputs[0]))
292 return 1;
293
294 if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
295 {
296 printNodeValidationError("OpDim: axis must between [0, input_rank - 1]");
297 return 1;
298 }
299
300 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
301 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
302
303 ASSERT_MEM(in && out);
304
305 return 0;
306}
307
308template <int Rank, TOSA_REF_TYPE Dtype>
309int OpDim<Rank, Dtype>::eval()
310{
311 int32_t axis = attribute->axis();
312 int64_t out_val = in->getShape()[axis];
313
Tai Ly8690a082023-12-18 20:40:24 +0000314 this->out->getTensor().setValues({ out_val });
Won Jeona21b2e82023-08-10 10:33:01 +0000315
Jerry Ge12159fc2024-04-01 17:05:10 +0000316 // set the shapeValue given the actual tensor value
317 std::vector<int> shapeValue;
318 for (int i = 0; i < out->getTensor().size(); ++i)
319 {
320 shapeValue.push_back(out->getTensor()(i));
321 }
322
323 this->getOutputs()[0]->setShapeValue(shapeValue);
324
Won Jeona21b2e82023-08-10 10:33:01 +0000325 return GraphNode::eval();
326}
327
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RESHAPE, id_)
{
    // Two inputs (data, shape), one output. Per the note in
    // checkTensorAttributes(), the shape input is not used by eval().
    setRequiredOperands(2, 1);
}
334
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
// Nothing to release: the constructor allocates no attribute.
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700338
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
    // Validates types and element counts for RESHAPE and caches typed
    // pointers for eval(). Returns 0 on success, 1 on validation failure.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(InRank <= tosa_level.MAX_RANK, "InRank should be smaller than or equal to MAX_RANK");
    LEVEL_CHECK(OutRank <= tosa_level.MAX_RANK, "OutRank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    // output and input must be the same types
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpReshape: Input and output types must match");
        return 1;
    }

    // Reshape never changes the number of elements.
    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
             "Input tensor size does not match output tensor size");

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // note: do not assert mem on shape input, because it may be {} for reshape to scalar
    // and also, because the shape input is not actually used in eval()

    ASSERT_MEM(in && out)

    return 0;
}
370
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
    // Build the reversed target shape and the output-dimension reverser in
    // one pass over the output rank.
    for (int32_t d = 0; d < OutRank; d++)
    {
        array_shape[d]  = getOutputs()[0]->getShape()[OutRank - 1 - d];
        out_reverser[d] = OutRank - 1 - d;
    }

    for (int32_t d = 0; d < InRank; d++)
    {
        in_reverser[d] = InRank - 1 - d;
    }

    // Eigen Tensor is col-major, and we're referencing row-major result
    // need to reverse it to row-major before reshape, and perform another reverse afterward

    // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
    TIn in_reversed;
    if (InRank > 1)
    {
        in_reversed = in->getTensor().shuffle(in_reverser);
    }
    else
    {
        // rank 0/1: shuffling is a no-op (or unsupported), copy directly
        in_reversed = in->getTensor();
    }

    TOut in_reshaped = in_reversed.reshape(array_shape);

    // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
    if (OutRank > 1)
    {
        out->getTensor() = in_reshaped.shuffle(out_reverser);
    }
    else
    {
        out->getTensor() = in_reshaped;
    }

    return GraphNode::eval();
}
413
template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_REVERSE, id_)
{
    // One input, one output; the axis to reverse comes from the attribute.
    setRequiredOperands(1, 1);
    setRequiredRank(1);

    // Copies the Axis attribute into 'attribute' (freed in the destructor).
    INIT_ATTRIBUTE(Axis);
}
423
Tai Lya4d748b2023-03-28 22:06:56 +0000424template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700425OpReverse<Rank, Dtype>::~OpReverse()
426{
427 if (attribute)
428 delete attribute;
429}
430
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
    // Validates operand count, rank/type/shape matching and the axis
    // attribute, then precomputes the per-dimension reverse mask used by
    // eval(). Returns 0 on success, 1 on validation failure.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank/type/shape");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
    {
        printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
        return 1;
    }

    // transform list of axis into true or false list
    // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
    for (int i = 0; i < Rank; i++)
    {
        reverse_array[i] = false;
    }
    reverse_array[attribute->axis()] = true;

    return 0;
}
474
Tai Lya4d748b2023-03-28 22:06:56 +0000475template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700476int OpReverse<Rank, Dtype>::eval()
477{
478 out->getTensor() = in->getTensor().reverse(reverse_array);
479
480 return GraphNode::eval();
481}
482
template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_SLICE, id_)
{
    // Three inputs (data, start, size), one output.
    setRequiredOperands(3, 1);
    setRequiredRank(1);
}
490
template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
// Nothing to release: the constructor allocates no attribute.
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700494
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
    // Validates operand count and rank/type matching for SLICE, and caches
    // typed pointers to the data/start/size inputs and the output.
    // Returns 0 on success, 1 on validation failure.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank or type");
        return 1;
    }

    // inputs[1]/inputs[2] are the 1-D start and size tensors; their lengths
    // are validated against rank(input) in eval().
    in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    start = dynamic_cast<TosaReference::TensorTemplate<TSlicing>*>(inputs[1]);
    size  = dynamic_cast<TosaReference::TensorTemplate<TSlicing>*>(inputs[2]);
    out   = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out && start && size);

    return 0;
}
526
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::eval()
{
    // Extracts the sub-tensor described by the start/size input tensors.
    TSlicing start_tensor = start->getTensor();
    TSlicing size_tensor  = size->getTensor();

    // According to https://eigen.tuxfamily.org/dox/unsupported/eigen_tensors.html
    // The type of size() is <Tensor-Type>::Index, but can always handily use it like an int.
    // However, apply explicit cast to int32_t is preferred.
    ERROR_IF(static_cast<int32_t>(start_tensor.size()) != in->getRank(),
             "OpSlice: start array length needs to be rank(input)");
    ERROR_IF(static_cast<int32_t>(size_tensor.size()) != in->getRank(),
             "OpSlice: size array length needs to be rank(input)");

    // Per dimension: validate that [start, start+size) lies inside the
    // input and that size matches the output extent exactly.
    for (int32_t i = 0; i < in->getRank(); i++)
    {
        int32_t b = start_tensor(i);
        int32_t s = size_tensor(i);
        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
        ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
        ERROR_IF(s <= 0, "OpSlice: output must be positive");
        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
        begin_array[i] = b;
        size_array[i]  = s;
    }

    out->getTensor() = in->getTensor().slice(begin_array, size_array);

    return GraphNode::eval();
}
557
template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_TILE, id_)
{
    // Two inputs (data, multiples), one output.
    setRequiredOperands(2, 1);
    setRequiredRank(1);
}
565
template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
// Nothing to release: the constructor allocates no attribute.
{}
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
    // Validates operand count, rank/type matching and the multiples input
    // for TILE, and caches typed pointers for the rank-specialized eval().
    // Returns 0 on success, 1 on validation failure.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same ranks and types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank or type");
        return 1;
    }

    in        = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    multiples = dynamic_cast<TosaReference::TensorTemplate<TInMultiples>*>(inputs[1]);
    out       = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && multiples && out);

    // 'multiples' must carry one repetition count per input dimension.
    if (multiples->getElementCount() != Rank)
    {
        printNodeValidationError("1D list 'multiples' must have size equal to input rank");
        return 1;
    }

    return 0;
}
606
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTile<Rank, Dtype>::eval()
{
    // primary template shouldn't be called
    // (ranks 1-6 are implemented by the explicit specializations in this file)
    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype));
}
613
Tai Lya4d748b2023-03-28 22:06:56 +0000614template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700615int OpTile<1, Dtype>::eval()
616{
617 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
618 {
619 int32_t id0 = od0 % this->in->getShape()[0];
620 this->out->getTensor()(od0) = this->in->getTensor()(id0);
621 }
622
623 return GraphNode::eval();
624}
625
Tai Lya4d748b2023-03-28 22:06:56 +0000626template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700627int OpTile<2, Dtype>::eval()
628{
629 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
630 {
631 int32_t id0 = od0 % this->in->getShape()[0];
632 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
633 {
634 int32_t id1 = od1 % this->in->getShape()[1];
635 this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
636 }
637 }
638
639 return GraphNode::eval();
640}
641
Tai Lya4d748b2023-03-28 22:06:56 +0000642template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700643int OpTile<3, Dtype>::eval()
644{
645 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
646 {
647 int32_t id0 = od0 % this->in->getShape()[0];
648 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
649 {
650 int32_t id1 = od1 % this->in->getShape()[1];
651 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
652 {
653 int32_t id2 = od2 % this->in->getShape()[2];
654 this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
655 }
656 }
657 }
658
659 return GraphNode::eval();
660}
661
Tai Lya4d748b2023-03-28 22:06:56 +0000662template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700663int OpTile<4, Dtype>::eval()
664{
665 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
666 {
667 int32_t id0 = od0 % this->in->getShape()[0];
668 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
669 {
670 int32_t id1 = od1 % this->in->getShape()[1];
671 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
672 {
673 int32_t id2 = od2 % this->in->getShape()[2];
674 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
675 {
676 int32_t id3 = od3 % this->in->getShape()[3];
677 this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
678 }
679 }
680 }
681 }
682
683 return GraphNode::eval();
684}
685
Tai Lya4d748b2023-03-28 22:06:56 +0000686template <TOSA_REF_TYPE Dtype>
Luke Huttona4e48ca2023-02-22 11:53:48 +0000687int OpTile<5, Dtype>::eval()
688{
689 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
690 {
691 int32_t id0 = od0 % this->in->getShape()[0];
692 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
693 {
694 int32_t id1 = od1 % this->in->getShape()[1];
695 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
696 {
697 int32_t id2 = od2 % this->in->getShape()[2];
698 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
699 {
700 int32_t id3 = od3 % this->in->getShape()[3];
701 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
702 {
703 int32_t id4 = od4 % this->in->getShape()[4];
704 this->out->getTensor()(od0, od1, od2, od3, od4) =
705 this->in->getTensor()(id0, id1, id2, id3, id4);
706 }
707 }
708 }
709 }
710 }
711
712 return GraphNode::eval();
713}
714
Tai Lya4d748b2023-03-28 22:06:56 +0000715template <TOSA_REF_TYPE Dtype>
Luke Huttona4e48ca2023-02-22 11:53:48 +0000716int OpTile<6, Dtype>::eval()
717{
718 for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
719 {
720 int32_t id0 = od0 % this->in->getShape()[0];
721 for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
722 {
723 int32_t id1 = od1 % this->in->getShape()[1];
724 for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
725 {
726 int32_t id2 = od2 % this->in->getShape()[2];
727 for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
728 {
729 int32_t id3 = od3 % this->in->getShape()[3];
730 for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
731 {
732 int32_t id4 = od4 % this->in->getShape()[4];
733 for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
734 {
735 int32_t id5 = od5 % this->in->getShape()[5];
736 this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
737 this->in->getTensor()(id0, id1, id2, id3, id4, id5);
738 }
739 }
740 }
741 }
742 }
743 }
744
745 return GraphNode::eval();
746}
747
template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE, id_)
{
    // One input, one output; the permutation comes from the attribute.
    setRequiredOperands(1, 1);
    setRequiredRank(1);

    // Copies the Transpose attribute (perms) into 'attribute' (freed in dtor).
    INIT_ATTRIBUTE(Transpose);
}
757
Tai Lya4d748b2023-03-28 22:06:56 +0000758template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700759OpTranspose<Rank, Dtype>::~OpTranspose()
Jerry Gea6827492022-11-16 10:41:55 -0800760{
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000761 if (attribute)
762 delete attribute;
Jerry Gea6827492022-11-16 10:41:55 -0800763}
Eric Kunzee5e26762020-10-13 16:11:07 -0700764
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
    // Validates operand count, rank/type/element-count matching and the
    // perms attribute (a valid permutation of [0, Rank)), then caches the
    // permutation into perm_array for eval().
    // Returns 0 on success, 1 on validation failure.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same types
    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("Failure to match input and output rank and type");
        return 1;
    }

    if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
    {
        printNodeValidationError("Failure to match input and output total element count");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)");

    // Each index must appear exactly once (perms is a permutation), and
    // each permuted input extent must equal the matching output extent.
    std::array<bool, Rank> index_used;
    index_used.fill(false);
    for (int32_t d = 0; d < Rank; d++)
    {
        int32_t index = attribute->perms()[d];
        ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
        ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
        index_used[index] = true;
        ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
        perm_array[d] = index;
    }

    return 0;
}
814
Tai Lya4d748b2023-03-28 22:06:56 +0000815template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700816int OpTranspose<Rank, Dtype>::eval()
817{
Eric Kunzee5e26762020-10-13 16:11:07 -0700818 out->getTensor() = in->getTensor().shuffle(perm_array);
819
820 return GraphNode::eval();
821}
822
823// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +0100824DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
James Ward24dbc422022-10-19 12:20:31 +0100825DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100826DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700827DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
828DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
829DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
830DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
Tai Lya4d748b2023-03-28 22:06:56 +0000831DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
Won Jeon2c34b462024-02-06 18:37:00 +0000832DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
833DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700834
James Ward8b390432022-08-12 20:48:56 +0100835DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100836DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100837DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700838DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
839DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
840DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
841DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000842DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000843DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
844DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700845
Won Jeona21b2e82023-08-10 10:33:01 +0000846DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
847DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16);
848DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP32);
849DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
850DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
851DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32);
852DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
Won Jeon2c34b462024-02-06 18:37:00 +0000853DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
854DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
Won Jeona21b2e82023-08-10 10:33:01 +0000855
James Ward8b390432022-08-12 20:48:56 +0100856DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100857DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100858DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700859DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
860DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
861DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
862DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000863DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000864DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3);
865DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700866
James Ward8b390432022-08-12 20:48:56 +0100867DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100868DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100869DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700870DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
871DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
872DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
873DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000874DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000875DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
876DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700877
Luke Huttona4e48ca2023-02-22 11:53:48 +0000878DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
879DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
880DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
881DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
882DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
883DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
884DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000885DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000886DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
887DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700888
Luke Huttona4e48ca2023-02-22 11:53:48 +0000889DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
890DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
891DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
892DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
893DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
894DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
895DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000896DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000897DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3);
898DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2);
Jared Smolens98c281f2022-12-20 15:09:25 -0800899
Luke Huttona4e48ca2023-02-22 11:53:48 +0000900DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
901DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
902DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
903DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
904DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
905DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
906DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000907DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000908DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
909DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700910
Luke Huttona4e48ca2023-02-22 11:53:48 +0000911DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
912DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
913DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
914DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
915DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
916DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
917DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000918DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000919DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
920DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);