blob: f38f48676b129846a3af8e56e42987ac78970145 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "tensor_ops.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000017#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070018#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
Kevin Cheng9fe17242021-11-10 01:04:39 +000025int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
26 std::vector<int32_t> input_shape,
27 std::vector<int32_t> output_shape,
28 std::string& msg)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000029{
TatWai Chong86c403b2022-06-06 20:46:01 -070030 if (attribute->pad().size() != 4)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000031 {
32 msg = "illegal size for attribute padding";
33 return 1;
34 }
35
36 if (attribute->kernel().size() != 2)
37 {
38 msg = "illegal size for attribute kernel";
39 return 1;
40 }
41
42 if (attribute->stride().size() != 2)
43 {
44 msg = "illegal size for attribute stride";
45 return 1;
46 }
47
TatWai Chong86c403b2022-06-06 20:46:01 -070048 for (int32_t i : attribute->pad())
Kevin Cheng7eb93d72021-10-09 01:26:08 +000049 {
50 if (i < 0)
51 {
52 msg = "At least one pad is smaller than zero";
53 return 1;
54 }
55 }
56
57 for (int32_t i : attribute->kernel())
58 {
59 if (i < 1)
60 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000061 msg = "At least one kernel dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000062 return 1;
63 }
64 }
65
66 for (int32_t i : attribute->stride())
67 {
68 if (i < 1)
69 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000070 msg = "At least one stride dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000071 return 1;
72 }
73 }
74
75 int32_t IH = input_shape[1];
76 int32_t IW = input_shape[2];
77 int32_t OH = output_shape[1];
78 int32_t OW = output_shape[2];
79
TatWai Chong86c403b2022-06-06 20:46:01 -070080 int32_t pad_top = attribute->pad()[0];
81 int32_t pad_bottom = attribute->pad()[1];
82 int32_t pad_left = attribute->pad()[2];
83 int32_t pad_right = attribute->pad()[3];
Kevin Cheng7eb93d72021-10-09 01:26:08 +000084
85 int32_t stride_y = attribute->stride()[0];
86 int32_t stride_x = attribute->stride()[1];
87 int32_t kernel_y = attribute->kernel()[0];
88 int32_t kernel_x = attribute->kernel()[1];
89
90 if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
91 {
92 msg = "At least one pad is >= kernel dimension";
93 return 1;
94 }
95
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010096 int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
97 int32_t full_W = IW + pad_left + pad_right - kernel_x;
98
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 if ((full_H % stride_y != 0) || (full_W % stride_x != 0))
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100101 msg = "Parameters must yield exact integer output dimensions";
102 return 1;
103 }
104
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000105 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100106 {
107 msg = "Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000108 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000109 return 1;
110 }
111
112 return 0;
113}
114
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000115int check_conv_attribute(tosa::TosaConvAttribute* attribute,
Tai Lya4d748b2023-03-28 22:06:56 +0000116 uint32_t conv_dimension,
117 std::vector<int32_t> input_shape,
118 std::vector<int32_t> output_shape,
119 std::vector<int32_t> weights,
120 uint32_t offset_kernel,
121 TOSA_REF_TYPE InDtype,
122 TOSA_REF_TYPE WeightDtype,
123 std::string& msg)
Kevin Cheng9fe17242021-11-10 01:04:39 +0000124{
TatWai Chong86c403b2022-06-06 20:46:01 -0700125 if (attribute->pad().size() != (2 * conv_dimension))
Kevin Cheng9fe17242021-11-10 01:04:39 +0000126 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700127 msg = "Illegal size for attribute pad";
Kevin Cheng9fe17242021-11-10 01:04:39 +0000128 return 1;
129 }
130
131 if (attribute->stride().size() != conv_dimension)
132 {
133 msg = "Illegal size for attribute stride";
134 return 1;
135 }
136
137 if (attribute->dilation().size() != conv_dimension)
138 {
139 msg = "Illegal size for attribute dilation";
140 return 1;
141 }
142
TatWai Chong86c403b2022-06-06 20:46:01 -0700143 for (int32_t i : attribute->pad())
Kevin Cheng9fe17242021-11-10 01:04:39 +0000144 {
145 if (i < 0)
146 {
147 msg = "At least one pad is smaller than zero";
148 return 1;
149 }
150 }
151
152 for (int32_t i : attribute->stride())
153 {
154 if (i < 1)
155 {
156 msg = "At least one stride dimension is smaller than one";
157 return 1;
158 }
159 }
160
161 for (int32_t i : attribute->dilation())
162 {
163 if (i < 1)
164 {
165 msg = "At least one dilation dimension is smaller than one";
166 return 1;
167 }
168 }
169
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100170 ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")
171
TatWai Chongfd629052022-07-25 04:01:58 +0000172 int32_t offset_d = conv_dimension == 3 ? 1 : 0;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000173 int32_t ID = conv_dimension == 3 ? input_shape[1] : 1;
174 int32_t IH = input_shape[1 + offset_d];
175 int32_t IW = input_shape[2 + offset_d];
176 int32_t OD = conv_dimension == 3 ? output_shape[1] : 1;
177 int32_t OH = output_shape[1 + offset_d];
178 int32_t OW = output_shape[2 + offset_d];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100179
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000180 int32_t stride_d = conv_dimension == 3 ? attribute->stride()[0] : 1;
181 int32_t stride_y = attribute->stride()[0 + offset_d];
182 int32_t stride_x = attribute->stride()[1 + offset_d];
183 int32_t kernel_d = conv_dimension == 3 ? weights[offset_kernel] : 1;
184 int32_t kernel_h = weights[offset_kernel + offset_d];
185 int32_t kernel_w = weights[offset_kernel + 1 + offset_d];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100186 int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
187 int32_t dilation_y = attribute->dilation()[0 + offset_d];
188 int32_t dilation_x = attribute->dilation()[1 + offset_d];
189
190 offset_d *= 2;
TatWai Chong86c403b2022-06-06 20:46:01 -0700191 int32_t pad_d0 = conv_dimension == 3 ? attribute->pad()[0] : 0;
192 int32_t pad_d1 = conv_dimension == 3 ? attribute->pad()[1] : 0;
193 int32_t pad_top = attribute->pad()[0 + offset_d];
194 int32_t pad_bottom = attribute->pad()[1 + offset_d];
195 int32_t pad_left = attribute->pad()[2 + offset_d];
196 int32_t pad_right = attribute->pad()[3 + offset_d];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100197
198 int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
199 int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
200 int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;
201
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000202 if ((full_H % stride_y != 0) || (full_W % stride_x != 0) || (full_D % stride_d != 0))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100203 {
204 msg = "Parameters must yield exact integer output dimensions";
205 return 1;
206 }
207
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000208 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1) || (OD != (full_D / stride_d) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100209 {
210 std::string msg_d = "";
211 if (conv_dimension == 3)
212 {
213 msg_d += std::to_string((full_D / stride_d) + 1) + ",";
214 }
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000215 msg = "Mismatch between output shape provided and expected output shape (" + msg_d +
216 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100217 return 1;
218 }
219
Tai Lya4d748b2023-03-28 22:06:56 +0000220 if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
221 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000222 msg = "Input zero point must be zero for non-int8 data";
223 return 1;
224 }
Tai Lya4d748b2023-03-28 22:06:56 +0000225 if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
226 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000227 msg = "Weight zero point must be zero for non-int8 data";
228 return 1;
Kevin Cheng9fe17242021-11-10 01:04:39 +0000229 }
230
231 return 0;
232}
233
// Validate the shapes of FFT2D / RFFT2D operands. 'in_imag' is empty for
// RFFT2D (which has no imaginary input). Returns 0 when consistent; otherwise
// returns 1 and fills 'msg' with the reason.
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg)
{
    const bool is_rfft       = in_imag.empty();
    auto is_power_of_two = [](int32_t n) -> bool { return n > 0 && (n & (n - 1)) == 0; };

    // FFT kernels only support power-of-two spatial dimensions.
    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // RFFT does not have a second input
    if (!is_rfft && in_real != in_imag)
    {
        msg = "Mismatch between real input shape and imaginary input shape";
        return 1;
    }

    if (out_real != out_imag)
    {
        msg = "Mismatch between real output shape and imaginary output shape";
        return 1;
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (is_rfft)
    {
        // RFFT output keeps only the non-redundant half of the spectrum.
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    }
    else if (in_real[2] != out_real[2])
    {
        msg = "Input and output width don't match";
        return 1;
    }

    return 0;
}
313
Tai Lya4d748b2023-03-28 22:06:56 +0000314template <int Rank, TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000315OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700316 : GraphNode(sgt_, Op_ARGMAX, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700317{
318 setRequiredOperands(1, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000319 setRequiredRank(1);
Eric Kunzee5e26762020-10-13 16:11:07 -0700320
321 INIT_ATTRIBUTE(Axis);
322}
323
Tai Lya4d748b2023-03-28 22:06:56 +0000324template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700325OpArgMax<Rank, Dtype>::~OpArgMax()
326{
327 if (attribute)
328 delete attribute;
329}
330
Tai Lya4d748b2023-03-28 22:06:56 +0000331template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700332int OpArgMax<Rank, Dtype>::checkTensorAttributes()
333{
334 if (validateRequiredOperands())
335 return 1;
336
Kevin Chengcc61be32021-10-14 17:09:57 -0700337 if (validateRequiredRank(inputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700338 {
339 return 1;
340 }
341
Kevin Chengcc61be32021-10-14 17:09:57 -0700342 int32_t output_rank = inputs[0]->getRank() - 1;
343 if (output_rank != outputs[0]->getRank())
344 {
345 printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
346 return 1;
347 }
348
Tai Lya4d748b2023-03-28 22:06:56 +0000349 if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
Kevin Chengcc61be32021-10-14 17:09:57 -0700350 {
351 printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
352 return 1;
353 }
354
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
356 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
357
Kevin Chengcc61be32021-10-14 17:09:57 -0700358 if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
359 {
360 printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
361 return 1;
362 }
363
364 bool shape_check = true;
365 for (int32_t i = 0; i < input->getRank(); i++)
366 {
367 if (i < attribute->axis())
368 {
369 if (input->getShape()[i] != output->getShape()[i])
370 {
371 shape_check = false;
372 break;
373 }
374 }
375 else if (i > attribute->axis())
376 {
377 if (input->getShape()[i] != output->getShape()[i - 1])
378 {
379 shape_check = false;
380 break;
381 }
382 }
383 // No need to check i == axis
384 }
385 if (!shape_check)
386 {
387 printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
388 return 1;
389 }
390
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 return 0;
392}
393
Tai Lya4d748b2023-03-28 22:06:56 +0000394template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700395int OpArgMax<Rank, Dtype>::eval()
396{
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000397 // Check Tosa Level
398 auto tosa_level = g_func_config.tosa_level;
399 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
400
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
402
403 this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
404
405 return GraphNode::eval();
406}
407
Tai Lya4d748b2023-03-28 22:06:56 +0000408template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000409OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700410 : GraphNode(sgt_, Op_AVG_POOL2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700411{
412 setRequiredOperands(1, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000413 setRequiredRank(4, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -0700414
Kevin Cheng93a16282021-08-31 16:14:03 -0700415 INIT_ATTRIBUTE(Pool);
Eric Kunzee5e26762020-10-13 16:11:07 -0700416}
417
Tai Lya4d748b2023-03-28 22:06:56 +0000418template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100419OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700420{
421 if (attribute)
422 delete attribute;
423}
424
Tai Lya4d748b2023-03-28 22:06:56 +0000425template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100426int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700427{
428 if (validateRequiredOperands())
429 return 1;
430
431 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
432 {
433 return 1;
434 }
435
436 if (inputs[0]->matchType(*outputs[0]))
437 {
438 printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
439 return 1;
440 }
441
442 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
443 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
444
Tai Lya4d748b2023-03-28 22:06:56 +0000445 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
446 "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
447 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
448 "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
Eric Kunzee5e26762020-10-13 16:11:07 -0700449
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000450 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +0000451 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700452 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000453 msg = "OpAvgPool2d: " + msg;
454 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return 1;
456 }
457
458 return 0;
459}
460
Eric Kunzee5e26762020-10-13 16:11:07 -0700461// assuming input and output tensor have same scales like tflite reference
462// so no need to scale input and output
Tai Lya4d748b2023-03-28 22:06:56 +0000463template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100464int OpAvgPool2d<Dtype, AccDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700465{
466 int in_batch = this->in->getShape()[0];
467 int in_height = this->in->getShape()[1];
468 int in_width = this->in->getShape()[2];
469 int in_channels = this->in->getShape()[3];
470
471 int out_batch = this->out->getShape()[0];
472 int out_height = this->out->getShape()[1];
473 int out_width = this->out->getShape()[2];
474 int out_channels = this->out->getShape()[3];
475
Kevin Chengacb550f2021-06-29 15:32:19 -0700476 ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
477 ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700478
TatWai Chong86c403b2022-06-06 20:46:01 -0700479 int pad_top = this->attribute->pad()[0];
480 int pad_bottom = this->attribute->pad()[1];
481 int pad_left = this->attribute->pad()[2];
482 int pad_right = this->attribute->pad()[3];
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000483 int kernel_y = this->attribute->kernel()[0];
484 int kernel_x = this->attribute->kernel()[1];
485 int stride_y = this->attribute->stride()[0];
486 int stride_x = this->attribute->stride()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000487
488 // Check Tosa Level
489 auto tosa_level = g_func_config.tosa_level;
490 LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
491 LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
492 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
493 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
494 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
495 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
496 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
497 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
Eric Kunzee5e26762020-10-13 16:11:07 -0700498
Tai Ly8ead6c42024-02-14 22:35:44 +0000499 TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->acc_type());
James Ward8b390432022-08-12 20:48:56 +0100500
Eric Kunzee5e26762020-10-13 16:11:07 -0700501 DEBUG_INFO(OP,
502 "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
Tai Ly8ead6c42024-02-14 22:35:44 +0000503 "stride=[%d,%d], pad=[%d,%d,%d,%d], acc_type=%s",
Jerry Gea793f462023-04-11 00:05:02 +0000504 in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
505 kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
TatWai Chong86c403b2022-06-06 20:46:01 -0700507 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
508 pad[0] = std::make_pair(0, 0);
509 pad[1] = std::make_pair(pad_top, pad_bottom);
510 pad[2] = std::make_pair(pad_left, pad_right);
511 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700512
513 ETensor4<InEigenType> input_val = this->in->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000514 if (Dtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700515 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000516 input_val = input_val - (InEigenType)attribute->input_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700517 }
518
Tai Ly307392a2023-05-12 21:42:19 +0000519 if (g_func_config.abs_mode)
520 {
James Ward29e46cf2023-10-23 11:47:25 +0000521 // in abs_mode: take abs values of input_val
522 input_val = input_val.abs();
Tai Ly307392a2023-05-12 21:42:19 +0000523 }
524
Eric Kunzee5e26762020-10-13 16:11:07 -0700525 // assuming input and output have same scales
526 // so input and output scaling is not required
527 // TODO: check if this assumption TOSA made
528
James Ward5a9e0cd2023-10-09 16:51:26 +0000529 ETensor4<OutEigenType> out_tens(out_batch, out_height, out_width, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700530
531 // sum pool
James Ward5a9e0cd2023-10-09 16:51:26 +0000532 for (int ob = 0; ob < out_batch; ++ob)
Eric Kunzee5e26762020-10-13 16:11:07 -0700533 {
James Ward5a9e0cd2023-10-09 16:51:26 +0000534 for (int oh = 0; oh < out_height; ++oh)
Eric Kunzee5e26762020-10-13 16:11:07 -0700535 {
James Ward5a9e0cd2023-10-09 16:51:26 +0000536 for (int ow = 0; ow < out_width; ++ow)
537 {
538 for (int oc = 0; oc < out_channels; ++oc)
539 {
540 AccEigenType acc(0);
541 int filter_count = 0;
542 const int iy = oh * stride_y - pad_top;
543 const int ix = ow * stride_x - pad_left;
544 for (int ky = 0; ky < kernel_y; ++ky)
545 {
546 for (int kx = 0; kx < kernel_x; ++kx)
547 {
548 const int y = iy + ky;
549 const int x = ix + kx;
550 if ((0 <= y && y < in_height) && (0 <= x && x < in_width))
551 {
552 ++filter_count;
James Ward29e46cf2023-10-23 11:47:25 +0000553 acc = acc + (AccEigenType)input_val(ob, y, x, oc);
James Ward5a9e0cd2023-10-09 16:51:26 +0000554 }
555 }
556 }
557 if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
Won Jeon2c34b462024-02-06 18:37:00 +0000558 Dtype != TOSA_REF_TYPE_FP64 && Dtype != TOSA_REF_TYPE_FP8E4M3 && Dtype != TOSA_REF_TYPE_FP8E5M2)
James Ward5a9e0cd2023-10-09 16:51:26 +0000559 {
560 try
561 {
562 int32_t multiplier, shift;
563 OutEigenType out;
564 TosaReference::QuantUtil::reciprocal_scale(filter_count, multiplier, shift);
565
566 out = (OutEigenType)TosaReference::QuantUtil::apply_scale_32(acc, multiplier, shift, false);
567 out = out + (OutEigenType)(attribute->output_zp());
568 out = std::max(out, (OutEigenType)QMin);
569 out_tens(ob, oh, ow, oc) = std::min(out, (OutEigenType)QMax);
570 }
571 catch (std::string desc)
572 {
573 REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
574 }
575 }
576 else
577 {
578 REQUIRE(filter_count != 0, "OpAvgPool2d number of filters should be non-zero.");
579 out_tens(ob, oh, ow, oc) = acc / static_cast<OutEigenType>(filter_count);
580 }
581 }
582 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700583 }
584 }
James Ward5a9e0cd2023-10-09 16:51:26 +0000585 this->out->getTensor() = out_tens;
Eric Kunzee5e26762020-10-13 16:11:07 -0700586 return GraphNode::eval();
587}
588
Tai Lyf36f2562024-03-14 16:21:29 +0000589template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
590OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
591 TosaAttributeBase* attribute_,
592 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700593 : GraphNode(sgt_, Op_CONV2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700594{
595 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000596 setRequiredRank(4, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -0700597
Kevin Cheng93a16282021-08-31 16:14:03 -0700598 INIT_ATTRIBUTE(Conv);
Eric Kunzee5e26762020-10-13 16:11:07 -0700599}
600
Tai Lyf36f2562024-03-14 16:21:29 +0000601template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
602OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700603{
604 if (attribute)
605 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700606}
607
Tai Lyf36f2562024-03-14 16:21:29 +0000608template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
609int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700610{
611 if (validateRequiredOperands())
612 return 1;
613
614 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
615 {
616 return 1;
617 }
618
619 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
620 if (inputs[2]->getRank() != 1)
621 {
622 printNodeValidationError("OpConv2d: bias tensor must be rank 1");
623 }
624
James Wardd34b3fc2023-01-18 14:51:25 +0000625 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000626 "OpConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700627
Eric Kunzee5e26762020-10-13 16:11:07 -0700628 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
629 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
630 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100631 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700632
Kevin Cheng9fe17242021-11-10 01:04:39 +0000633 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000634 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000635 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700636 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000637 msg = "OpConv2d: " + msg;
638 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 return 1;
640 }
641
Eric Kunzee5e26762020-10-13 16:11:07 -0700642 return 0;
643}
644
Tai Lyf36f2562024-03-14 16:21:29 +0000645template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
646int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700647{
648 int in_batch = this->input->getShape()[0];
649 int in_height = this->input->getShape()[1];
650 int in_width = this->input->getShape()[2];
651 int in_channels = this->input->getShape()[3];
652
653 int f_out_channels = this->weight->getShape()[0];
654 int f_height = this->weight->getShape()[1];
655 int f_width = this->weight->getShape()[2];
656 int f_in_channels = this->weight->getShape()[3];
657
658 int b_out_channels = this->bias->getShape()[0];
659
660 int out_batch = this->output->getShape()[0];
661 int out_height = this->output->getShape()[1];
662 int out_width = this->output->getShape()[2];
663 int out_channels = this->output->getShape()[3];
664
Kevin Chengacb550f2021-06-29 15:32:19 -0700665 ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
666 ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
667 in_channels);
668 ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
669 out_channels);
Tai Lya641dd52023-08-11 19:58:50 +0000670 ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d",
671 b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700672
TatWai Chong86c403b2022-06-06 20:46:01 -0700673 int pad_top = this->attribute->pad()[0];
674 int pad_bottom = this->attribute->pad()[1];
675 int pad_left = this->attribute->pad()[2];
676 int pad_right = this->attribute->pad()[3];
677
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000678 int stride_y = this->attribute->stride()[0];
679 int stride_x = this->attribute->stride()[1];
680 int dilation_y = this->attribute->dilation()[0];
681 int dilation_x = this->attribute->dilation()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000682
683 // Check Tosa Level
684 auto tosa_level = g_func_config.tosa_level;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000685 LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
686 "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
687 LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
688 "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +0000689 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
690 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
691 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
692 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
693 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
694 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
Eric Kunzee5e26762020-10-13 16:11:07 -0700695
696 DEBUG_INFO(OP,
697 "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +0000698 "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
Eric Kunzee5e26762020-10-13 16:11:07 -0700699 in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000700 out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
701 pad_left, pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -0700702
703 // GEMM-conv2d, left matrix is input, right matrix is weight
704 Eigen::array<Eigen::Index, 2> im2col_input_dims;
705 im2col_input_dims[0] = out_batch * out_height * out_width;
706 im2col_input_dims[1] = f_height * f_width * f_in_channels;
707
708 Eigen::array<Eigen::Index, 2> im2col_weight_dims;
709 im2col_weight_dims[0] = f_height * f_width * f_in_channels;
710 im2col_weight_dims[1] = f_out_channels;
711
712 Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
713 bias_reshaped_dims[0] = 1;
714 bias_reshaped_dims[1] = b_out_channels;
715
716 Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
717 weight_zp_bcast_dims[0] = f_height;
718 weight_zp_bcast_dims[1] = f_width;
719 weight_zp_bcast_dims[2] = f_in_channels;
720
721 Eigen::array<Eigen::Index, 2> bias_bcast_dims;
722 bias_bcast_dims[0] = out_batch * out_height * out_width;
Tai Lya641dd52023-08-11 19:58:50 +0000723 bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700724
725 Eigen::array<Eigen::Index, 4> col2im_output_dims;
726 col2im_output_dims[0] = out_batch;
727 col2im_output_dims[1] = out_height;
728 col2im_output_dims[2] = out_width;
729 col2im_output_dims[3] = out_channels;
730
731 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
732
TatWai Chong86c403b2022-06-06 20:46:01 -0700733 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
734 pad[0] = std::make_pair(0, 0);
735 pad[1] = std::make_pair(pad_top, pad_bottom);
736 pad[2] = std::make_pair(pad_left, pad_right);
737 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700738
739 TIn input_val = this->input->getTensor();
740 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000741 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700742 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000743 input_val = input_val - (InEigenType)attribute->input_zp();
744 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700745 }
746
TatWai Chong86c403b2022-06-06 20:46:01 -0700747 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700748
Tai Ly307392a2023-05-12 21:42:19 +0000749 TBias bias_val = this->bias->getTensor();
750
751 if (g_func_config.abs_mode)
752 {
753 // in abs_mode: take abs values of conv operands
754 input_padded = input_padded.abs();
755 weight_val = weight_val.abs();
756 bias_val = bias_val.abs();
757 }
758
Eric Kunzee5e26762020-10-13 16:11:07 -0700759 // extract_image_patches() output [N, KH, KW, H * W, C]
760 // need to transpose to [N, H * W, KH, KW, C]
761 ETensor5<InEigenType> input_extract_patches =
762 input_padded
Jerry Gea793f462023-04-11 00:05:02 +0000763 .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700764 .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
765
766 // reshape input to [N * H * W, KH * KW * C]
767 ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
768
769 // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
770 ETensor2<WeightEigenType> im2col_weight =
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000771 weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
Eric Kunzee5e26762020-10-13 16:11:07 -0700772
773 // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
774 // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
Tai Ly307392a2023-05-12 21:42:19 +0000775 ETensor2<OutEigenType> bias_2d =
776 (bias_val.reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700777
778 // output matrix is [N * H * W, C]
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000779 ETensor2<OutEigenType> contracted_result = (im2col_input.template cast<AccEigenType>().contract(
780 im2col_weight.template cast<AccEigenType>(), contract_dims))
781 .template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700782
783 // adding bias
James Ward8b390432022-08-12 20:48:56 +0100784 ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700785
786 // reshape back to [N, H, W, C]
787 this->output->getTensor() = biased_output.reshape(col2im_output_dims);
788
Tai Lya4d748b2023-03-28 22:06:56 +0000789 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -0700790 {
James Ward8b390432022-08-12 20:48:56 +0100791 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
792 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700793 }
794
795 return GraphNode::eval();
796}
797
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                                             TosaAttributeBase* attribute_,
                                                             uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    // CONV3D consumes three operands (input, weight, bias) and produces one output.
    setRequiredOperands(3, 1);
    // Tensors checked via validateRequiredRank() must be exactly rank 5.
    // (bias is rank 1 and is validated separately in checkTensorAttributes()).
    setRequiredRank(5, 5);

    // Deserialize the Conv attribute (pad/stride/dilation/zero points) into 'attribute'.
    INIT_ATTRIBUTE(Conv);
}
809
Tai Lyf36f2562024-03-14 16:21:29 +0000810template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
811OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv3d()
Kevin Cheng1533b852021-09-01 12:51:58 -0700812{
813 if (attribute)
814 delete attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700815}
816
Tai Lyf36f2562024-03-14 16:21:29 +0000817template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
818int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
Kevin Cheng1533b852021-09-01 12:51:58 -0700819{
820 if (validateRequiredOperands())
821 return 1;
822
823 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
824 {
825 return 1;
826 }
827
828 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
829 if (inputs[2]->getRank() != 1)
830 {
831 printNodeValidationError("OpConv3d: bias tensor must be rank 1");
832 }
833
James Wardd34b3fc2023-01-18 14:51:25 +0000834 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000835 "OpConv3d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700836
Kevin Cheng1533b852021-09-01 12:51:58 -0700837 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
838 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
839 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100840 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng1533b852021-09-01 12:51:58 -0700841
Kevin Cheng9fe17242021-11-10 01:04:39 +0000842 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000843 if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000844 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Kevin Cheng1533b852021-09-01 12:51:58 -0700845 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000846 msg = "OpConv3d: " + msg;
847 printNodeValidationError(msg.c_str());
Kevin Cheng1533b852021-09-01 12:51:58 -0700848 return 1;
849 }
850
Kevin Cheng1533b852021-09-01 12:51:58 -0700851 return 0;
852}
853
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
    // Reference (non-optimized) CONV3D: pad the input, initialize the output
    // with the (broadcast) bias, then accumulate a direct 7-deep-loop
    // convolution in AccEigenType precision.

    // Input layout is [N, D, H, W, C] (indices below follow that order).
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    // Weight layout is [OC, KD, KH, KW, IC].
    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    // Bias is rank 1: either [OC] or [1] (broadcast).
    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    // Cross-tensor shape consistency checks.
    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d",
             b_out_channels, out_channels);

    // pad() ordering: [d0, d1, top, bottom, left, right].
    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    // stride()/dilation() ordering: [depth, y, x].
    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level: enforce implementation-defined kernel/stride bounds.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL,
                "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    // Zero padding on D/H/W only; batch and channel dims are never padded.
    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    // Zero points only apply to int8 operands per the TOSA spec; subtract them
    // before accumulation so the arithmetic is done on zero-centered values.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        // (used by the compliance flow to compute error bounds).
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // 1. initialize with bias: reshape [OC] (or [1]) to rank 5 and broadcast
    //    across N/D/H/W (and across channels when the bias is a scalar).
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    bcast[4] = (b_out_channels == 1) ? out_channels : 1;
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            // Map output position + kernel tap to the padded-input index.
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        // Multiply-accumulate in AccEigenType precision.
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    // INT48 accumulators are wider than the output storage type; saturate to
    // the representable 48-bit range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1014
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                               TosaAttributeBase* attribute_,
                                                                               uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    // DEPTHWISE_CONV2D consumes three operands (input, weight, bias) and produces one output.
    setRequiredOperands(3, 1);
    // Tensors checked via validateRequiredRank() must be exactly rank 4.
    // (bias is rank 1 and is validated separately in checkTensorAttributes()).
    setRequiredRank(4, 4);

    // Deserialize the Conv attribute (pad/stride/dilation/zero points) into 'attribute'.
    INIT_ATTRIBUTE(Conv);
}
1026
Tai Lyf36f2562024-03-14 16:21:29 +00001027template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
1028OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpDepthwiseConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001029{
1030 if (attribute)
1031 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001032}
1033
Tai Lyf36f2562024-03-14 16:21:29 +00001034template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
1035int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001036{
1037 if (validateRequiredOperands())
1038 return 1;
1039
1040 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1041 {
1042 return 1;
1043 }
1044
1045 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
1046 if (inputs[2]->getRank() != 1)
1047 {
1048 printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
1049 }
1050
James Wardd34b3fc2023-01-18 14:51:25 +00001051 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001052 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001053
Eric Kunzee5e26762020-10-13 16:11:07 -07001054 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1055 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1056 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001057 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001058
Kevin Cheng9fe17242021-11-10 01:04:39 +00001059 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001060 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001061 weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001062 {
Kevin Cheng9fe17242021-11-10 01:04:39 +00001063 msg = "OpDepthwiseConv2d: " + msg;
1064 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001065 return 1;
1066 }
1067
Eric Kunzee5e26762020-10-13 16:11:07 -07001068 return 0;
1069}
1070
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
    // Reference DEPTHWISE_CONV2D: each input channel 'ic' is convolved with
    // 'f_multiplier' filters, producing output channel ic * f_multiplier + cm.

    // Input layout is [N, H, W, C].
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    // Weight layout is [KH, KW, IC, M] (M = channel multiplier).
    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    // Bias is rank 1: either [OC] or [1] (broadcast).
    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // Cross-tensor shape consistency checks; note OC must equal IC * M.
    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
             "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);

    // pad() ordering: [top, bottom, left, right].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    // stride()/dilation() ordering: [y, x].
    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level: enforce implementation-defined kernel/stride bounds.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
               pad_left, pad_right);

    // Zero padding on H/W only; batch and channel dims are never padded.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    // Zero points only apply to int8 operands per the TOSA spec; subtract them
    // before accumulation so the arithmetic is done on zero-centered values.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        // (used by the compliance flow to compute error bounds).
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    // Reshape bias from [OC] (or [1]) to rank 4 and broadcast across N/H/W
    // (and across channels when the bias is a scalar).
    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = (b_out_channels == 1) ? out_channels : 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType.
                                // Patch index 'ow * out_height + oh' matches the layout
                                // extract_image_patches() produces for these tensors.
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh,
                                                                                       ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    // INT48 accumulators are wider than the output storage type; saturate to
    // the representable 48-bit range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1212
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                                   TosaAttributeBase* attribute_,
                                                                   uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    // FULLY_CONNECTED consumes three operands (input, weight, bias) and produces one output.
    setRequiredOperands(3, 1);
    // Tensors checked via validateRequiredRank() must be exactly rank 2.
    setRequiredRank(2, 2);

    // Deserialize the FullyConnected attribute (zero points) into 'attribute'.
    INIT_ATTRIBUTE(FullyConnected);
}
1224
Tai Lya4d748b2023-03-28 22:06:56 +00001225template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001226OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001227{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001228 if (attribute)
1229 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001230}
1231
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    // Returns 0 on success, 1 when any operand/attribute validation fails.
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // Bind the typed tensor views used by eval().
    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    // Inner (contraction) dimensions must agree: input is [N, IC], weight is [OC, IC].
    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    // Bias is either per-output-channel ([OC]) or a broadcast scalar ([1]).
    if (weight->getShape()[0] != bias->getShape()[0] && bias->getShape()[0] != 1)
    {
        printNodeValidationError(
            "OpFullyConnected operator bias.shape[0] should match weight.shape[0] or be equal to 1");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Per the TOSA spec, non-zero zero points are only permitted for int8 data.
    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
             "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
1272
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    // FULLY_CONNECTED: output[N, OC] = input[N, IC] x weight[OC, IC]^T + bias.

    // Contract input dim 1 (IC) against dim 0 of the transposed weight.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    // Transpose weight from [OC, IC] to [IC, OC] for the contraction.
    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    int b_out_channels = this->bias->getShape()[0];
    int out_channels   = this->output->getShape()[1];

    // Bias is either per-output-channel ([OC]) or a broadcast scalar ([1]).
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d",
             b_out_channels, out_channels);

    // Reshape bias [OC] (or [1]) to [1, OC] and broadcast to [N, OC]
    // (broadcasting across channels too when the bias is a scalar).
    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = b_out_channels;

    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    // Zero points only apply to int8 operands per the TOSA spec; subtract them
    // before accumulation so the arithmetic is done on zero-centered values.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        // (used by the compliance flow to compute error bounds).
        input_val  = input_val.abs();
        weight_val = weight_val.abs();
        bias_val   = bias_val.abs();
    }

    // Matrix multiply in AccEigenType precision, cast to the output type, then add bias.
    this->output->getTensor() = input_val.template cast<AccEigenType>()
                                    .contract(weight_val.template cast<AccEigenType>(), dims)
                                    .template cast<OutEigenType>() +
                                bias_val.reshape(bias_reshape).broadcast(bias_bcast);

    // INT48 accumulators are wider than the output storage type; saturate to
    // the representable 48-bit range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}
1325
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    // MATMUL consumes two input tensors (a, b) and produces one output.
    setRequiredOperands(2, 1);
    // All operands must be rank 3: a=[N,H,C], b=[N,C,W], output=[N,H,W].
    setRequiredRank(3, 3);

    // Populates this->attribute from attribute_ (zero points a_zp/b_zp);
    // released in the destructor.
    INIT_ATTRIBUTE(MatMul);
}
1335
Tai Lya4d748b2023-03-28 22:06:56 +00001336template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001337OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001338{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001339 if (attribute)
1340 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001341}
1342
Tai Lya4d748b2023-03-28 22:06:56 +00001343template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001344int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001345{
1346 if (validateRequiredOperands())
1347 return 1;
1348
1349 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1350 {
1351 return 1;
1352 }
1353
James Wardd34b3fc2023-01-18 14:51:25 +00001354 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001355 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001356
Kevin Cheng2d60f002021-06-09 14:18:32 -07001357 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1358 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001359 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001360
Kevin Cheng2d60f002021-06-09 14:18:32 -07001361 ASSERT_MEM(a && b && output);
1362
1363 // a: [N, H, C]
1364 // b: [N, C, W]
1365 // c: [N, H, W]
1366
1367 // Check N
1368 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001369 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001370 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001371 return 1;
1372 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001373 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001374
Kevin Cheng2d60f002021-06-09 14:18:32 -07001375 // Check C
1376 if (a->getShape()[2] != b->getShape()[1])
1377 {
1378 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1379 return 1;
1380 }
1381 C = a->getShape()[2];
1382
1383 // Check H
1384 if (a->getShape()[1] != output->getShape()[1])
1385 {
1386 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1387 return 1;
1388 }
1389 H = a->getShape()[1];
1390
1391 // Check W
1392 if (b->getShape()[2] != output->getShape()[2])
1393 {
1394 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1395 return 1;
1396 }
1397 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001398
Tai Lya4d748b2023-03-28 22:06:56 +00001399 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
1400 "OpMatMul: A zeropoint must be zero for non int8_t data");
1401 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
1402 "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001403
Eric Kunzee5e26762020-10-13 16:11:07 -07001404 return 0;
1405}
1406
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    // Batched matrix multiply: for each batch n, output[n] = a[n] . b[n],
    // accumulating in AccEigenType and casting the result to OutEigenType.

    // Contract a's last dimension (C) with b's first dimension (C).
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        // INT8 operands carry zero points that are subtracted before multiply.
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of matmul operands
        a_val = a_val.abs();
        b_val = b_val.abs();
    }

    // Rank-2 views of one batch slice of a and b, and the rank-3 shape used
    // to put each per-batch result back into [1, H, W] form.
    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        // Move the slice window to batch i.
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        // Widen to the accumulator type before contracting.
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            // Append this batch's result along the batch axis.
            TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // INT48 accumulators must be clamped to the 48-bit quantized range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1468
template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    // One input tensor, one output tensor.
    setRequiredOperands(1, 1);
    // NHWC layout: both input and output must be rank 4.
    setRequiredRank(4, 4);

    // Populates this->attribute (kernel/stride/pad); released in the destructor.
    INIT_ATTRIBUTE(Pool);
}
1478
Tai Lya4d748b2023-03-28 22:06:56 +00001479template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001480OpMaxPool2d<Dtype>::~OpMaxPool2d()
1481{
1482 if (attribute)
1483 delete attribute;
1484}
1485
Tai Lya4d748b2023-03-28 22:06:56 +00001486template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001487int OpMaxPool2d<Dtype>::checkTensorAttributes()
1488{
1489 if (validateRequiredOperands())
1490 return 1;
1491
1492 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
1493 {
1494 return 1;
1495 }
1496
1497 if (inputs[0]->matchType(*outputs[0]))
1498 {
1499 printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
1500 return 1;
1501 }
1502
1503 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1504 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1505
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001506 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +00001507 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001508 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001509 msg = "OpMaxPool2d: " + msg;
1510 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001511 return 1;
1512 }
1513
1514 return 0;
1515}
1516
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    // Computes 2D max pooling via an im2col transform: pad the input, extract
    // every KHxKW patch, take the argmax along the patch axis, then reshape
    // the result back to NHWC.
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // Attribute layout: pad=[top, bottom, left, right], kernel/stride=[y, x].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    // im2col matrix: one column per output element, one row per kernel tap.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Only the spatial (H, W) dimensions are padded.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    // Pad with the most negative representable value so padding never wins the max.
    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KHxHW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1609
template <TOSA_REF_TYPE Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    // Two inputs (real, imaginary) and two outputs (real, imaginary).
    setRequiredOperands(2, 2);
    // All tensors are rank 3: [N, H, W].
    setRequiredRank(3, 3);

    // Populates this->attribute (e.g. the inverse flag); released in the destructor.
    INIT_ATTRIBUTE(FFT);
}
1619
Tai Lya4d748b2023-03-28 22:06:56 +00001620template <TOSA_REF_TYPE Dtype>
1621OpFFT2d<Dtype>::~OpFFT2d()
1622{
Luke Hutton57287132023-02-06 14:54:18 +00001623 if (attribute)
1624 delete attribute;
1625}
1626
Tai Lya4d748b2023-03-28 22:06:56 +00001627template <TOSA_REF_TYPE Dtype>
Luke Hutton57287132023-02-06 14:54:18 +00001628int OpFFT2d<Dtype>::checkTensorAttributes()
1629{
1630 if (validateRequiredOperands())
1631 return 1;
1632
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001633 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]) ||
1634 validateRequiredRank(outputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001635 {
1636 return 1;
1637 }
1638
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001639 if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || inputs[0]->matchType(*inputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001640 {
1641 printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
1642 return 1;
1643 }
1644
1645 in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1646 in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
1647 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1648 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1649
1650 ASSERT_MEM(in_real && in_imag && out_real && out_imag);
1651
1652 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001653 if (check_fft_shape(in_real->getShape(), in_imag->getShape(), out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton57287132023-02-06 14:54:18 +00001654 {
1655 msg = "OpFFT2d: " + msg;
1656 printNodeValidationError(msg.c_str());
1657 return 1;
1658 }
1659
1660 return 0;
1661}
1662
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
    // Direct (non-fast) 2D DFT over each batch: for every output point the
    // full HxW input is summed against cos/sin coefficients. The inverse
    // attribute only flips the sign of the twiddle angle.
    int in_real_batch  = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width  = this->in_real->getShape()[2];

    int in_imag_batch  = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width  = this->in_imag->getShape()[2];

    int out_real_batch  = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width  = this->out_real->getShape()[2];

    int out_imag_batch  = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width  = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP, "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, sign_val = 1.0;
    OutEigenType a, a_cos, a_sin, v_ir;

    // Inverse transform negates the twiddle-factor angle.
    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

    TIn in_real_val = this->in_real->getTensor();
    TIn in_imag_val = this->in_imag->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of real and imag operands
        in_real_val = in_real_val.abs();
        in_imag_val = in_imag_val.abs();
    }

    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = in_real_val(n, iy, ix);
                        OutEigenType val_imag = in_imag_val(n, iy, ix);
                        // Perform the periodic calculation in integer maths to keep
                        // the accuracy of the co-efficients similar for FP32 normal
                        // and FP64 precise mode
                        int32_t ay = (static_cast<int64_t>(iy) * static_cast<int64_t>(oy)) % in_real_height;
                        int32_t ax = (static_cast<int64_t>(ix) * static_cast<int64_t>(ox)) % in_real_width;

                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI *
                            ((OutEigenType)ay / in_real_height + (OutEigenType)ax / in_real_width);
                        // Calculate weight values
                        a_cos = cos(a);
                        a_sin = sin(a);
                        if (g_func_config.abs_mode)
                        {
                            // Bounded op - Use abs weight values
                            a_cos = std::abs(a_cos);
                            a_sin = std::abs(a_sin);
                            // Bounded op - Use abs real value for imaginary calc
                            v_ir = val_real;
                        }
                        else
                        {
                            // Normal op - Use negative real value for imaginary calc
                            v_ir = -val_real;
                        }
                        sum_real += val_real * a_cos + val_imag * a_sin;
                        sum_imag += v_ir * a_sin + val_imag * a_cos;
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1760
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    // One real-valued input; two outputs (real and imaginary spectrum halves).
    setRequiredOperands(1, 2);
    // All tensors are rank 3: [N, H, W].
    setRequiredRank(3, 3);

    // Populates this->attribute; released in the destructor.
    INIT_ATTRIBUTE(RFFT);
}
1770
Tai Lya4d748b2023-03-28 22:06:56 +00001771template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001772OpRFFT2d<Dtype>::~OpRFFT2d()
Tai Lyfd8fde82023-11-13 20:18:14 +00001773{
1774 if (attribute)
1775 delete attribute;
1776}
Luke Hutton261b7b62023-01-10 14:50:31 +00001777
Tai Lya4d748b2023-03-28 22:06:56 +00001778template <TOSA_REF_TYPE Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +00001779int OpRFFT2d<Dtype>::checkTensorAttributes()
1780{
1781 if (validateRequiredOperands())
1782 return 1;
1783
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001784 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
Luke Hutton261b7b62023-01-10 14:50:31 +00001785 {
1786 return 1;
1787 }
1788
1789 if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
1790 {
1791 printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
1792 return 1;
1793 }
1794
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001795 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
Luke Hutton261b7b62023-01-10 14:50:31 +00001796 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1797 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1798
1799 ASSERT_MEM(in && out_real && out_imag);
1800
Luke Hutton57287132023-02-06 14:54:18 +00001801 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001802 if (check_fft_shape(in->getShape(), {}, out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton261b7b62023-01-10 14:50:31 +00001803 {
Luke Hutton57287132023-02-06 14:54:18 +00001804 msg = "OpRFFT2d: " + msg;
1805 printNodeValidationError(msg.c_str());
Luke Hutton261b7b62023-01-10 14:50:31 +00001806 return 1;
1807 }
1808
1809 return 0;
1810}
1811
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
    // Direct (non-fast) real-input 2D DFT per batch. Only the non-redundant
    // half of the spectrum is produced (out width is driven by the output
    // tensor shapes validated in checkTensorAttributes).
    int32_t in_batch  = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width  = in->getShape()[2];

    int32_t out_real_batch  = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width  = out_real->getShape()[2];

    int32_t out_imag_batch  = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width  = out_imag->getShape()[2];

    // Midpoints used to identify spectrum locations whose imaginary part is
    // exactly zero by symmetry.
    int32_t half_in_height = in_height / 2;
    int32_t half_in_width  = in_width / 2;

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width, out_real_batch, out_real_height, out_real_width, out_imag_batch,
               out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag;
    OutEigenType a, a_cos, a_sin, v_ir;

    TIn in_val = this->in->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in operand
        in_val = in_val.abs();
    }

    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        OutEigenType val = in_val(n, iy, ix);
                        // Perform the periodic calculation in integer maths to keep
                        // the accuracy of the co-efficients similar for FP32 normal
                        // and FP64 precise mode
                        int32_t ay = (static_cast<int64_t>(iy) * static_cast<int64_t>(oy)) % in_height;
                        int32_t ax = (static_cast<int64_t>(ix) * static_cast<int64_t>(ox)) % in_width;

                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((OutEigenType)ay / in_height + (OutEigenType)ax / in_width);

                        // Calculate weight values (co-efficients)
                        a_cos = cos(a);
                        a_sin = sin(a);

                        if (g_func_config.abs_mode)
                        {
                            // Bounded op - Use abs weight values
                            a_cos = std::abs(a_cos);
                            a_sin = std::abs(a_sin);
                            // Bounded op - Use abs real value for imaginary calc
                            v_ir = val;
                        }
                        else
                        {
                            // Normal op - Use negative real value for imaginary calc
                            v_ir = -val;
                        }
                        sum_real += val * a_cos;
                        // Imaginary values with locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.
                        // But due to sin(M_PI) not returning 0 because of M_PI being approximate, only
                        // add to the imaginary sum when not processing these locations.
                        if ((in_height > 1 && (ay % (half_in_height)) > 0) ||
                            (in_width > 1 && (ax % (half_in_width)) > 0))
                        {
                            sum_imag += v_ir * a_sin;
                        }
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1910
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                               TosaAttributeBase* attribute_,
                                                                               uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    // Three inputs (input, weight, bias) and one output.
    setRequiredOperands(3, 1);
    // Input/weight/output tensors must be rank 4 (NHWC / OHWI layouts).
    setRequiredRank(4, 4);

    // Populates this->attribute (out_pad, stride, zero points); released in
    // the destructor.
    INIT_ATTRIBUTE(TransposeConv);
}
1922
Tai Lyf36f2562024-03-14 16:21:29 +00001923template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
1924OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001925{
1926 if (attribute)
1927 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001928}
1929
Tai Lyf36f2562024-03-14 16:21:29 +00001930template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
1931int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001932{
1933 if (validateRequiredOperands())
1934 return 1;
1935
1936 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1937 {
1938 return 1;
1939 }
1940
James Wardd34b3fc2023-01-18 14:51:25 +00001941 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001942 "OpTransposeConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001943
Eric Kunzee5e26762020-10-13 16:11:07 -07001944 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1945 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1946 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001947 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001948
TatWai Chong24594f52022-06-08 00:48:04 -07001949 if (attribute->out_pad().size() != 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001950 {
TatWai Chong24594f52022-06-08 00:48:04 -07001951 printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
Eric Kunzee5e26762020-10-13 16:11:07 -07001952 return 1;
1953 }
1954
1955 if (attribute->stride().size() != 2)
1956 {
1957 printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
1958 return 1;
1959 }
1960
Kevin Cheng9fe17242021-11-10 01:04:39 +00001961 for (int32_t i : attribute->stride())
1962 {
1963 if (i < 1)
1964 {
1965 printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
1966 return 1;
1967 }
1968 }
1969
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001970 int32_t IH = input->getShape()[1];
1971 int32_t IW = input->getShape()[2];
1972 int32_t OH = output->getShape()[1];
1973 int32_t OW = output->getShape()[2];
1974
1975 int32_t stride_y = attribute->stride()[0];
1976 int32_t stride_x = attribute->stride()[1];
1977 int32_t kernel_h = weight->getShape()[1];
1978 int32_t kernel_w = weight->getShape()[2];
1979
TatWai Chong24594f52022-06-08 00:48:04 -07001980 int32_t out_pad_top = attribute->out_pad()[0];
1981 int32_t out_pad_bottom = attribute->out_pad()[1];
1982 int32_t out_pad_left = attribute->out_pad()[2];
1983 int32_t out_pad_right = attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001984
Eric Kunzec1a97832022-07-01 16:56:09 -07001985 for (size_t i = 0; i < attribute->out_pad().size(); i++)
1986 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001987 ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
1988 "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
Eric Kunzec1a97832022-07-01 16:56:09 -07001989 }
1990
1991 int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
1992 int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001993
1994 if ((OH != H) || (OW != W))
1995 {
1996 std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001997 std::to_string(H) + "," + std::to_string(W) + ")";
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001998 printNodeValidationError(msg.c_str());
1999 return 1;
2000 }
2001
Tai Lya4d748b2023-03-28 22:06:56 +00002002 ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
2003 "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
2004 ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
2005 "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07002006
Eric Kunzee5e26762020-10-13 16:11:07 -07002007 return 0;
2008}
2009
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
    // Reference evaluation of TRANSPOSE_CONV2D: initialise the output with the
    // (broadcast) bias, then scatter-accumulate every input element through the
    // kernel window into the output. Returns the chained GraphNode::eval() on
    // success; the ERROR_IF macros bail out early on shape mismatches.

    // Tensor layouts (from the indexing below): input is NHWC, weight is
    // [out_channels, KH, KW, in_channels], bias is [out_channels] (or [1] for
    // a broadcast bias), output is NHWC.
    int in_batch = this->input->getShape()[0];
    int in_height = this->input->getShape()[1];
    int in_width = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height = this->weight->getShape()[1];
    int f_width = this->weight->getShape()[2];
    int f_in_channels = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch = this->output->getShape()[0];
    int out_height = this->output->getShape()[1];
    int out_width = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // out_pad is [top, bottom, left, right]; values may be negative (checked
    // against kernel size in checkTensorAttributes).
    int out_pad_top = this->attribute->out_pad()[0];
    int out_pad_bottom = this->attribute->out_pad()[1];
    int out_pad_left = this->attribute->out_pad()[2];
    int out_pad_right = this->attribute->out_pad()[3];

    // stride is [y, x], i.e. [height, width] order.
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Cross-tensor consistency checks; a bias of 1 channel broadcasts across
    // all output channels.
    ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
             f_out_channels, out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
             "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);

    // Check Tosa Level: kernel extents, out_pad and strides are bounded by the
    // configured TOSA level limits.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL,
                "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, out_pad_top, out_pad_bottom, out_pad_left,
               out_pad_right);

    // Subtract zero points before accumulating. Only INT8 tensors may carry a
    // non-zero zero point (enforced in checkTensorAttributes).
    TIn input_val = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        // (NOTE(review): abs_mode appears to be used for floating-point
        // error-bound checking — confirm against g_func_config docs.)
        input_val = input_val.abs();
        weight_val = weight_val.abs();
        bias_val = bias_val.abs();
    }

    // Reshape bias to [1,1,1,C] and broadcast it over the whole output.
    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    // A single-channel bias is additionally broadcast across channels.
    bcast[3] = (b_out_channels == 1) ? out_channels : 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    int out_x_origin, out_y_origin;
    int out_x, out_y;

    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
    // Each input element (ob, ih, iw, ic) contributes to the KHxKW window of
    // output positions anchored at (ih*stride_y + out_pad_top,
    // iw*stride_x + out_pad_left); out-of-bounds targets are skipped.
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int ih = 0; ih < in_height; ih++)
        {
            for (int iw = 0; iw < in_width; iw++)
            {
                out_x_origin = iw * stride_x + out_pad_left;
                out_y_origin = ih * stride_y + out_pad_top;
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int fh = 0; fh < f_height; fh++)
                    {
                        for (int fw = 0; fw < f_width; fw++)
                        {
                            out_x = out_x_origin + fw;
                            out_y = out_y_origin + fh;
                            for (int oc = 0; oc < out_channels; oc++)
                            {
                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                                {
                                    // Multiply in the accumulator type, then
                                    // narrow to the output element type.
                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
                                        (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
                                                       (AccEigenType)weight_val(oc, fh, fw, ic));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate INT48 results to the accumulator's quantized range
        // (AccQMin/AccQMax are class-level constants declared elsewhere).
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
2140
// template explicit instantiation
// Each DEF_INSTANTIATE_* invocation emits the concrete template
// specializations for one op / dtype combination supported by the spec.

// OpArgMax: one element type; supports input ranks 1-6 (per macro name).
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2);

// OpAvgPool2d: [in_t, acc_t] — e.g. INT8 inputs accumulate in INT32.
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16);

// [in_t, weight_t, acc_t, out_t]
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP32, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, BF16, BF16, FP32, BF16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP32, FP32, FP32, FP32);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT4, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT8, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT16, INT8, INT48, INT48);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP64, FP64, FP64, FP64);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16, FP16);

// OpConv3d: same [in_t, weight_t, acc_t, out_t] combinations as OpConv2d.
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP32, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, BF16, BF16, FP32, BF16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP32, FP32, FP32, FP32);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT4, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT8, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT16, INT8, INT48, INT48);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP64, FP64, FP64, FP64);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16, FP16);

// OpDepthwiseConv2d: same [in_t, weight_t, acc_t, out_t] combinations.
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32, BF16);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32, FP32);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48, INT48);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64, FP64);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16, FP16);

// OpFFT2d: single floating-point element type.
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);

// OpFullyConnected: [in_t, weight_t, out_t] (presumably out_t doubles as the
// accumulator type here — confirm against the op's class template).
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP8E4M3, FP8E4M3, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP8E5M2, FP8E5M2, FP16);

// OpMatMul: [in_t, out_t] — e.g. INT8 inputs produce INT32 outputs.
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E4M3, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E5M2, FP16);

// OpMaxPool2d: one element type (output type matches input).
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E4M3);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2);

// OpRFFT2d: single floating-point element type.
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);

// [in_t, weight_t, acc_t, out_t]
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP32, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, BF16, BF16, FP32, BF16);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP32, FP32, FP32, FP32);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT4, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT8, INT32, INT32);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT16, INT8, INT48, INT48);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP64, FP64, FP64, FP64);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16, FP16);