
// Copyright (c) 2020-2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

#include "tensor_ops.h"
#include "half.hpp"
#include "quant_util.h"
#include "template_types.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
24
Kevin Cheng9fe17242021-11-10 01:04:39 +000025int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
26 std::vector<int32_t> input_shape,
27 std::vector<int32_t> output_shape,
28 std::string& msg)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000029{
TatWai Chong86c403b2022-06-06 20:46:01 -070030 if (attribute->pad().size() != 4)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000031 {
32 msg = "illegal size for attribute padding";
33 return 1;
34 }
35
36 if (attribute->kernel().size() != 2)
37 {
38 msg = "illegal size for attribute kernel";
39 return 1;
40 }
41
42 if (attribute->stride().size() != 2)
43 {
44 msg = "illegal size for attribute stride";
45 return 1;
46 }
47
TatWai Chong86c403b2022-06-06 20:46:01 -070048 for (int32_t i : attribute->pad())
Kevin Cheng7eb93d72021-10-09 01:26:08 +000049 {
50 if (i < 0)
51 {
52 msg = "At least one pad is smaller than zero";
53 return 1;
54 }
55 }
56
57 for (int32_t i : attribute->kernel())
58 {
59 if (i < 1)
60 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000061 msg = "At least one kernel dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000062 return 1;
63 }
64 }
65
66 for (int32_t i : attribute->stride())
67 {
68 if (i < 1)
69 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000070 msg = "At least one stride dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000071 return 1;
72 }
73 }
74
75 int32_t IH = input_shape[1];
76 int32_t IW = input_shape[2];
77 int32_t OH = output_shape[1];
78 int32_t OW = output_shape[2];
79
TatWai Chong86c403b2022-06-06 20:46:01 -070080 int32_t pad_top = attribute->pad()[0];
81 int32_t pad_bottom = attribute->pad()[1];
82 int32_t pad_left = attribute->pad()[2];
83 int32_t pad_right = attribute->pad()[3];
Kevin Cheng7eb93d72021-10-09 01:26:08 +000084
85 int32_t stride_y = attribute->stride()[0];
86 int32_t stride_x = attribute->stride()[1];
87 int32_t kernel_y = attribute->kernel()[0];
88 int32_t kernel_x = attribute->kernel()[1];
89
90 if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
91 {
92 msg = "At least one pad is >= kernel dimension";
93 return 1;
94 }
95
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010096 int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
97 int32_t full_W = IW + pad_left + pad_right - kernel_x;
98
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 if ((full_H % stride_y != 0) || (full_W % stride_x != 0))
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100101 msg = "Parameters must yield exact integer output dimensions";
102 return 1;
103 }
104
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000105 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100106 {
107 msg = "Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000108 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000109 return 1;
110 }
111
112 return 0;
113}
114
// Validate a convolution attribute (pad/stride/dilation plus zero points)
// against the supplied tensor shapes, for CONV2D (conv_dimension == 2) or
// CONV3D (conv_dimension == 3).
//
//   input_shape / output_shape : activation shapes (NHWC or NDHWC)
//   weights                    : weight tensor shape
//   offset_kernel              : index of the first spatial kernel dim
//                                inside 'weights' (e.g. 1 for CONV2D [OC,KH,KW,IC])
//   InDtype / WeightDtype      : used only for the zero-point rules below
//
// Returns 0 on success; returns 1 and fills 'msg' on the first failure.
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         TOSA_REF_TYPE InDtype,
                         TOSA_REF_TYPE WeightDtype,
                         std::string& msg)
{
    // Attribute array sizes scale with the convolution dimensionality:
    // 2 pads per spatial axis, one stride/dilation entry per spatial axis.
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    // Element-wise range checks: pads >= 0, stride >= 1, dilation >= 1.
    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    // offset_d shifts all H/W indices by one when a depth axis is present.
    // The depth quantities (ID/OD, stride_d, kernel_d, dilation_d, pad_d0/d1)
    // collapse to neutral values (1 or 0) for the 2D case so the same
    // arithmetic below covers both CONV2D and CONV3D.
    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID       = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH       = input_shape[1 + offset_d];
    int32_t IW       = input_shape[2 + offset_d];
    int32_t OD       = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH       = output_shape[1 + offset_d];
    int32_t OW       = output_shape[2 + offset_d];

    int32_t stride_d   = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y   = attribute->stride()[0 + offset_d];
    int32_t stride_x   = attribute->stride()[1 + offset_d];
    int32_t kernel_d   = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h   = weights[offset_kernel + offset_d];
    int32_t kernel_w   = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    // The pad array holds two entries per axis, so the H/W offset doubles.
    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

    // Effective span available to the (dilated) kernel along each axis.
    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    // The span must divide evenly by the stride...
    if ((full_H % stride_y != 0) || (full_W % stride_x != 0) || (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    // ...and must reproduce the declared output shape exactly.
    if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1) || (OD != (full_D / stride_d) + 1))
    {
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" + msg_d +
              std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    // Zero points are only meaningful for INT8 data; all other dtypes must
    // carry a zero point of 0.
    if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
    {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
    {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}
233
// Validate the input/output shapes of FFT2D and RFFT2D.
//
// Shapes are rank-3 [N, H, W].  An empty 'in_imag' marks the RFFT2D case,
// which has no imaginary input and produces only the non-redundant half of
// the spectrum (output width == input width / 2 + 1).
//
// Returns 0 on success; returns 1 and fills 'msg' on the first failure.
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg)
{
    // RFFT2D is identified by the absence of an imaginary input tensor.
    const bool is_rfft       = in_imag.empty();
    auto is_power_of_two     = [](int32_t v) -> bool { return v > 0 && (v & (v - 1)) == 0; };

    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // For FFT2D the real and imaginary inputs must agree dimension by dimension.
    if (!is_rfft)
    {
        for (size_t dim = 0; dim < in_real.size(); dim++)
        {
            if (in_real[dim] != in_imag[dim])
            {
                msg = "Mismatch between real input shape and imaginary input shape";
                return 1;
            }
        }
    }

    // Real and imaginary outputs must always agree dimension by dimension.
    for (size_t dim = 0; dim < out_real.size(); dim++)
    {
        if (out_real[dim] != out_imag[dim])
        {
            msg = "Mismatch between real output shape and imaginary output shape";
            return 1;
        }
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (is_rfft)
    {
        // RFFT2D keeps only the non-redundant half of the spectrum.
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    }
    else
    {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}
313
// ARGMAX node constructor: one input, one output; input rank must be >= 1.
// INIT_ATTRIBUTE presumably allocates and populates 'attribute' from
// attribute_ — released in the destructor (confirm against the macro).
template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1);

    INIT_ATTRIBUTE(Axis);
}
323
Tai Lya4d748b2023-03-28 22:06:56 +0000324template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700325OpArgMax<Rank, Dtype>::~OpArgMax()
326{
327 if (attribute)
328 delete attribute;
329}
330
Tai Lya4d748b2023-03-28 22:06:56 +0000331template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700332int OpArgMax<Rank, Dtype>::checkTensorAttributes()
333{
334 if (validateRequiredOperands())
335 return 1;
336
Kevin Chengcc61be32021-10-14 17:09:57 -0700337 if (validateRequiredRank(inputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700338 {
339 return 1;
340 }
341
Kevin Chengcc61be32021-10-14 17:09:57 -0700342 int32_t output_rank = inputs[0]->getRank() - 1;
343 if (output_rank != outputs[0]->getRank())
344 {
345 printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
346 return 1;
347 }
348
Tai Lya4d748b2023-03-28 22:06:56 +0000349 if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
Kevin Chengcc61be32021-10-14 17:09:57 -0700350 {
351 printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
352 return 1;
353 }
354
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
356 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
357
Kevin Chengcc61be32021-10-14 17:09:57 -0700358 if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
359 {
360 printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
361 return 1;
362 }
363
364 bool shape_check = true;
365 for (int32_t i = 0; i < input->getRank(); i++)
366 {
367 if (i < attribute->axis())
368 {
369 if (input->getShape()[i] != output->getShape()[i])
370 {
371 shape_check = false;
372 break;
373 }
374 }
375 else if (i > attribute->axis())
376 {
377 if (input->getShape()[i] != output->getShape()[i - 1])
378 {
379 shape_check = false;
380 break;
381 }
382 }
383 // No need to check i == axis
384 }
385 if (!shape_check)
386 {
387 printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
388 return 1;
389 }
390
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 return 0;
392}
393
Tai Lya4d748b2023-03-28 22:06:56 +0000394template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700395int OpArgMax<Rank, Dtype>::eval()
396{
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000397 // Check Tosa Level
398 auto tosa_level = g_func_config.tosa_level;
399 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
400
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
402
403 this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
404
405 return GraphNode::eval();
406}
407
// AVG_POOL2D node constructor: one input, one output; both tensors must be
// rank 4 (NHWC).  INIT_ATTRIBUTE presumably allocates 'attribute' from
// attribute_ — released in the destructor (confirm against the macro).
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4, 4);

    INIT_ATTRIBUTE(Pool);
}
417
Tai Lya4d748b2023-03-28 22:06:56 +0000418template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100419OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700420{
421 if (attribute)
422 delete attribute;
423}
424
// Validate operand count, ranks, matching input/output type, zero points and
// the shared pool2d attribute constraints.  Returns 0 on success, 1 on failure.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // NOTE(review): matchType() appears to return nonzero on a type mismatch
    // (error path taken when it is truthy) — confirm against its declaration.
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Zero points are only meaningful for INT8 data; otherwise they must be 0.
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
             "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");

    // Shared pad/kernel/stride/output-shape validation for 2D pooling.
    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
460
// This calculates the number of padding elements used for each location along an axis
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
//
//   in_size / out_size   : input/output extent along this axis
//   kernel_size / stride : pooling window size and step along this axis
//   pad_left / pad_right : padding added before/after the input on this axis
//
// Returns a 1-D tensor of length out_size whose entry i is the divisor
// (count of real, non-padding input elements) for output position i.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
    int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
    ETensor1<int32_t> result(out_size);

    // Start from a full kernel everywhere; interior positions keep this value.
    result.setConstant(kernel_size);

    // adjust divisors on the left side for padding
    // We start at the leftmost output element, and remove pad_left - (index * stride) elements
    // until we have no more padding being used
    for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++)
    {
        int32_t adjust = pad_left - (index * stride);
        result(index) -= adjust;
    }

    // The process repeats on the right side. Padding starts taking effect as we
    // near the rightmost input element. The first output element which touches
    // padding is defined in the initialization of index below. Then we keep moving
    // to the right, increasing padding until we get to the last output element.
    int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
    for (; index < out_size; index++)
    {
        int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
        result(index) -= adjust;
    }
    return result;
}
493
494// assuming input and output tensor have same scales like tflite reference
495// so no need to scale input and output
Tai Lya4d748b2023-03-28 22:06:56 +0000496template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100497int OpAvgPool2d<Dtype, AccDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700498{
499 int in_batch = this->in->getShape()[0];
500 int in_height = this->in->getShape()[1];
501 int in_width = this->in->getShape()[2];
502 int in_channels = this->in->getShape()[3];
503
504 int out_batch = this->out->getShape()[0];
505 int out_height = this->out->getShape()[1];
506 int out_width = this->out->getShape()[2];
507 int out_channels = this->out->getShape()[3];
508
Kevin Chengacb550f2021-06-29 15:32:19 -0700509 ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
510 ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700511
TatWai Chong86c403b2022-06-06 20:46:01 -0700512 int pad_top = this->attribute->pad()[0];
513 int pad_bottom = this->attribute->pad()[1];
514 int pad_left = this->attribute->pad()[2];
515 int pad_right = this->attribute->pad()[3];
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000516 int kernel_y = this->attribute->kernel()[0];
517 int kernel_x = this->attribute->kernel()[1];
518 int stride_y = this->attribute->stride()[0];
519 int stride_x = this->attribute->stride()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000520
521 // Check Tosa Level
522 auto tosa_level = g_func_config.tosa_level;
523 LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
524 LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
525 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
526 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
527 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
528 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
529 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
530 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
Eric Kunzee5e26762020-10-13 16:11:07 -0700531
Tai Lya4d748b2023-03-28 22:06:56 +0000532 TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
James Ward8b390432022-08-12 20:48:56 +0100533
Eric Kunzee5e26762020-10-13 16:11:07 -0700534 DEBUG_INFO(OP,
535 "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
James Ward8b390432022-08-12 20:48:56 +0100536 "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
Jerry Gea793f462023-04-11 00:05:02 +0000537 in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
538 kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700539
TatWai Chong86c403b2022-06-06 20:46:01 -0700540 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
541 pad[0] = std::make_pair(0, 0);
542 pad[1] = std::make_pair(pad_top, pad_bottom);
543 pad[2] = std::make_pair(pad_left, pad_right);
544 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700545
546 ETensor4<InEigenType> input_val = this->in->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000547 if (Dtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700548 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000549 input_val = input_val - (InEigenType)attribute->input_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700550 }
551
TatWai Chong86c403b2022-06-06 20:46:01 -0700552 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Tai Ly307392a2023-05-12 21:42:19 +0000554 if (g_func_config.abs_mode)
555 {
556 // in abs_mode: take abs values of input_padded
557 input_padded = input_padded.abs();
558 }
559
Eric Kunzee5e26762020-10-13 16:11:07 -0700560 // assuming input and output have same scales
561 // so input and output scaling is not required
562 // TODO: check if this assumption TOSA made
563
James Ward5a9e0cd2023-10-09 16:51:26 +0000564 ETensor4<OutEigenType> out_tens(out_batch, out_height, out_width, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700565
566 // sum pool
James Ward5a9e0cd2023-10-09 16:51:26 +0000567 for (int ob = 0; ob < out_batch; ++ob)
Eric Kunzee5e26762020-10-13 16:11:07 -0700568 {
James Ward5a9e0cd2023-10-09 16:51:26 +0000569 for (int oh = 0; oh < out_height; ++oh)
Eric Kunzee5e26762020-10-13 16:11:07 -0700570 {
James Ward5a9e0cd2023-10-09 16:51:26 +0000571 for (int ow = 0; ow < out_width; ++ow)
572 {
573 for (int oc = 0; oc < out_channels; ++oc)
574 {
575 AccEigenType acc(0);
576 int filter_count = 0;
577 const int iy = oh * stride_y - pad_top;
578 const int ix = ow * stride_x - pad_left;
579 for (int ky = 0; ky < kernel_y; ++ky)
580 {
581 for (int kx = 0; kx < kernel_x; ++kx)
582 {
583 const int y = iy + ky;
584 const int x = ix + kx;
585 if ((0 <= y && y < in_height) && (0 <= x && x < in_width))
586 {
587 ++filter_count;
588 acc = acc + (AccEigenType)input_padded(ob, y, x, oc);
589 }
590 }
591 }
592 if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
593 Dtype != TOSA_REF_TYPE_FP64)
594 {
595 try
596 {
597 int32_t multiplier, shift;
598 OutEigenType out;
599 TosaReference::QuantUtil::reciprocal_scale(filter_count, multiplier, shift);
600
601 out = (OutEigenType)TosaReference::QuantUtil::apply_scale_32(acc, multiplier, shift, false);
602 out = out + (OutEigenType)(attribute->output_zp());
603 out = std::max(out, (OutEigenType)QMin);
604 out_tens(ob, oh, ow, oc) = std::min(out, (OutEigenType)QMax);
605 }
606 catch (std::string desc)
607 {
608 REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
609 }
610 }
611 else
612 {
613 REQUIRE(filter_count != 0, "OpAvgPool2d number of filters should be non-zero.");
614 out_tens(ob, oh, ow, oc) = acc / static_cast<OutEigenType>(filter_count);
615 }
616 }
617 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700618 }
619 }
James Ward5a9e0cd2023-10-09 16:51:26 +0000620 this->out->getTensor() = out_tens;
Eric Kunzee5e26762020-10-13 16:11:07 -0700621 return GraphNode::eval();
622}
623
// CONV2D node constructor: three inputs (activations, weights, bias), one
// output; ranks must be 4 (bias rank is checked separately in
// checkTensorAttributes).  INIT_ATTRIBUTE presumably allocates 'attribute'
// from attribute_ — released in the destructor (confirm against the macro).
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4, 4);

    INIT_ATTRIBUTE(Conv);
}
633
Tai Lya4d748b2023-03-28 22:06:56 +0000634template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000635OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700636{
637 if (attribute)
638 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700639}
640
Tai Lya4d748b2023-03-28 22:06:56 +0000641template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000642int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700643{
644 if (validateRequiredOperands())
645 return 1;
646
647 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
648 {
649 return 1;
650 }
651
652 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
653 if (inputs[2]->getRank() != 1)
654 {
655 printNodeValidationError("OpConv2d: bias tensor must be rank 1");
656 }
657
James Wardd34b3fc2023-01-18 14:51:25 +0000658 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000659 "OpConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700660
Eric Kunzee5e26762020-10-13 16:11:07 -0700661 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
662 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
663 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100664 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700665
Kevin Cheng9fe17242021-11-10 01:04:39 +0000666 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000667 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000668 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700669 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000670 msg = "OpConv2d: " + msg;
671 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700672 return 1;
673 }
674
Eric Kunzee5e26762020-10-13 16:11:07 -0700675 return 0;
676}
677
Tai Lya4d748b2023-03-28 22:06:56 +0000678template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000679int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700680{
681 int in_batch = this->input->getShape()[0];
682 int in_height = this->input->getShape()[1];
683 int in_width = this->input->getShape()[2];
684 int in_channels = this->input->getShape()[3];
685
686 int f_out_channels = this->weight->getShape()[0];
687 int f_height = this->weight->getShape()[1];
688 int f_width = this->weight->getShape()[2];
689 int f_in_channels = this->weight->getShape()[3];
690
691 int b_out_channels = this->bias->getShape()[0];
692
693 int out_batch = this->output->getShape()[0];
694 int out_height = this->output->getShape()[1];
695 int out_width = this->output->getShape()[2];
696 int out_channels = this->output->getShape()[3];
697
Kevin Chengacb550f2021-06-29 15:32:19 -0700698 ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
699 ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
700 in_channels);
701 ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
702 out_channels);
Tai Lya641dd52023-08-11 19:58:50 +0000703 ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d",
704 b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700705
TatWai Chong86c403b2022-06-06 20:46:01 -0700706 int pad_top = this->attribute->pad()[0];
707 int pad_bottom = this->attribute->pad()[1];
708 int pad_left = this->attribute->pad()[2];
709 int pad_right = this->attribute->pad()[3];
710
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000711 int stride_y = this->attribute->stride()[0];
712 int stride_x = this->attribute->stride()[1];
713 int dilation_y = this->attribute->dilation()[0];
714 int dilation_x = this->attribute->dilation()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000715
716 // Check Tosa Level
717 auto tosa_level = g_func_config.tosa_level;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000718 LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
719 "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
720 LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
721 "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +0000722 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
723 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
724 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
725 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
726 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
727 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
Eric Kunzee5e26762020-10-13 16:11:07 -0700728
729 DEBUG_INFO(OP,
730 "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +0000731 "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
Eric Kunzee5e26762020-10-13 16:11:07 -0700732 in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000733 out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
734 pad_left, pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -0700735
736 // GEMM-conv2d, left matrix is input, right matrix is weight
737 Eigen::array<Eigen::Index, 2> im2col_input_dims;
738 im2col_input_dims[0] = out_batch * out_height * out_width;
739 im2col_input_dims[1] = f_height * f_width * f_in_channels;
740
741 Eigen::array<Eigen::Index, 2> im2col_weight_dims;
742 im2col_weight_dims[0] = f_height * f_width * f_in_channels;
743 im2col_weight_dims[1] = f_out_channels;
744
745 Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
746 bias_reshaped_dims[0] = 1;
747 bias_reshaped_dims[1] = b_out_channels;
748
749 Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
750 weight_zp_bcast_dims[0] = f_height;
751 weight_zp_bcast_dims[1] = f_width;
752 weight_zp_bcast_dims[2] = f_in_channels;
753
754 Eigen::array<Eigen::Index, 2> bias_bcast_dims;
755 bias_bcast_dims[0] = out_batch * out_height * out_width;
Tai Lya641dd52023-08-11 19:58:50 +0000756 bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700757
758 Eigen::array<Eigen::Index, 4> col2im_output_dims;
759 col2im_output_dims[0] = out_batch;
760 col2im_output_dims[1] = out_height;
761 col2im_output_dims[2] = out_width;
762 col2im_output_dims[3] = out_channels;
763
764 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
765
TatWai Chong86c403b2022-06-06 20:46:01 -0700766 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
767 pad[0] = std::make_pair(0, 0);
768 pad[1] = std::make_pair(pad_top, pad_bottom);
769 pad[2] = std::make_pair(pad_left, pad_right);
770 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700771
772 TIn input_val = this->input->getTensor();
773 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000774 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700775 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000776 input_val = input_val - (InEigenType)attribute->input_zp();
777 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700778 }
779
TatWai Chong86c403b2022-06-06 20:46:01 -0700780 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700781
Tai Ly307392a2023-05-12 21:42:19 +0000782 TBias bias_val = this->bias->getTensor();
783
784 if (g_func_config.abs_mode)
785 {
786 // in abs_mode: take abs values of conv operands
787 input_padded = input_padded.abs();
788 weight_val = weight_val.abs();
789 bias_val = bias_val.abs();
790 }
791
Eric Kunzee5e26762020-10-13 16:11:07 -0700792 // extract_image_patches() output [N, KH, KW, H * W, C]
793 // need to transpose to [N, H * W, KH, KW, C]
794 ETensor5<InEigenType> input_extract_patches =
795 input_padded
Jerry Gea793f462023-04-11 00:05:02 +0000796 .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700797 .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
798
799 // reshape input to [N * H * W, KH * KW * C]
800 ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
801
802 // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
803 ETensor2<WeightEigenType> im2col_weight =
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000804 weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
Eric Kunzee5e26762020-10-13 16:11:07 -0700805
806 // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
807 // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
Tai Ly307392a2023-05-12 21:42:19 +0000808 ETensor2<OutEigenType> bias_2d =
809 (bias_val.reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700810
811 // output matrix is [N * H * W, C]
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000812 ETensor2<OutEigenType> contracted_result = (im2col_input.template cast<AccEigenType>().contract(
813 im2col_weight.template cast<AccEigenType>(), contract_dims))
814 .template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700815
816 // adding bias
James Ward8b390432022-08-12 20:48:56 +0100817 ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700818
819 // reshape back to [N, H, W, C]
820 this->output->getTensor() = biased_output.reshape(col2im_output_dims);
821
Tai Lya4d748b2023-03-28 22:06:56 +0000822 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -0700823 {
James Ward8b390432022-08-12 20:48:56 +0100824 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
825 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700826 }
827
828 return GraphNode::eval();
829}
830
Tai Lya4d748b2023-03-28 22:06:56 +0000831template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000832OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Cheng1533b852021-09-01 12:51:58 -0700833 : GraphNode(sgt_, Op_CONV3D, id_)
834{
835 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000836 setRequiredRank(5, 5);
Kevin Cheng1533b852021-09-01 12:51:58 -0700837
838 INIT_ATTRIBUTE(Conv);
Kevin Cheng1533b852021-09-01 12:51:58 -0700839}
840
Tai Lya4d748b2023-03-28 22:06:56 +0000841template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000842OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
Kevin Cheng1533b852021-09-01 12:51:58 -0700843{
844 if (attribute)
845 delete attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700846}
847
Tai Lya4d748b2023-03-28 22:06:56 +0000848template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000849int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Kevin Cheng1533b852021-09-01 12:51:58 -0700850{
851 if (validateRequiredOperands())
852 return 1;
853
854 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
855 {
856 return 1;
857 }
858
859 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
860 if (inputs[2]->getRank() != 1)
861 {
862 printNodeValidationError("OpConv3d: bias tensor must be rank 1");
863 }
864
James Wardd34b3fc2023-01-18 14:51:25 +0000865 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000866 "OpConv3d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700867
Kevin Cheng1533b852021-09-01 12:51:58 -0700868 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
869 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
870 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100871 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng1533b852021-09-01 12:51:58 -0700872
Kevin Cheng9fe17242021-11-10 01:04:39 +0000873 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000874 if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000875 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Kevin Cheng1533b852021-09-01 12:51:58 -0700876 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000877 msg = "OpConv3d: " + msg;
878 printNodeValidationError(msg.c_str());
Kevin Cheng1533b852021-09-01 12:51:58 -0700879 return 1;
880 }
881
Kevin Cheng1533b852021-09-01 12:51:58 -0700882 return 0;
883}
884
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Evaluate a 3D convolution by direct (loop-nest) computation.
    // Layouts as indexed below: input [N, D, H, W, IC], weight [OC, KD, KH, KW, IC],
    // bias [OC] (or [1], broadcast), output [N, OD, OH, OW, OC].
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    // Cross-tensor consistency checks; a bias of size 1 is accepted and
    // broadcast across all output channels below.
    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d",
             b_out_channels, out_channels);

    // pad() layout: [d0, d1, top, bottom, left, right].
    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    // stride()/dilation() layout: [depth, y, x].
    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level: enforce the implementation-level limits on effective
    // kernel extent, padding and stride.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL,
                "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    // Zero-pad spatial dims only; batch (dim 0) and channel (dim 4) are untouched.
    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // INT8 operands carry zero-points; subtract them before convolving
        // (padding afterwards is then correct, since padded values are 0).
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // 1. initialize with bias
    //    Reshape bias [b_out_channels] -> [1,1,1,1,b_out_channels] and broadcast
    //    to the full output shape; a size-1 bias is also broadcast over channels.
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    bcast[4] = (b_out_channels == 1) ? out_channels : 1;
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value (stored in output above)
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            // Map output coordinate + kernel tap to padded-input coordinate.
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        // Multiply-accumulate in the wider accumulator type.
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate the 48-bit accumulator result into its legal range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1045
Tai Lya4d748b2023-03-28 22:06:56 +00001046template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001047OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
Tai Lya4d748b2023-03-28 22:06:56 +00001048 TosaAttributeBase* attribute_,
1049 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001050 : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001051{
1052 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001053 setRequiredRank(4, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -07001054
Kevin Cheng93a16282021-08-31 16:14:03 -07001055 INIT_ATTRIBUTE(Conv);
Eric Kunzee5e26762020-10-13 16:11:07 -07001056}
1057
Tai Lya4d748b2023-03-28 22:06:56 +00001058template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001059OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001060{
1061 if (attribute)
1062 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001063}
1064
Tai Lya4d748b2023-03-28 22:06:56 +00001065template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001066int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001067{
1068 if (validateRequiredOperands())
1069 return 1;
1070
1071 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1072 {
1073 return 1;
1074 }
1075
1076 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
1077 if (inputs[2]->getRank() != 1)
1078 {
1079 printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
1080 }
1081
James Wardd34b3fc2023-01-18 14:51:25 +00001082 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001083 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001084
Eric Kunzee5e26762020-10-13 16:11:07 -07001085 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1086 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1087 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001088 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001089
Kevin Cheng9fe17242021-11-10 01:04:39 +00001090 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001091 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001092 weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001093 {
Kevin Cheng9fe17242021-11-10 01:04:39 +00001094 msg = "OpDepthwiseConv2d: " + msg;
1095 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001096 return 1;
1097 }
1098
Eric Kunzee5e26762020-10-13 16:11:07 -07001099 return 0;
1100}
1101
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Evaluate a depthwise 2D convolution. Layouts as indexed below:
    // input [N, H, W, IC], weight [KH, KW, IC, M] (M = channel multiplier),
    // bias [OC] (or [1], broadcast), output [N, OH, OW, OC] with OC = IC * M.
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // Cross-tensor consistency checks; a bias of size 1 is accepted and
    // broadcast across all output channels below.
    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
             "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);

    // pad() layout: [top, bottom, left, right].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    // stride()/dilation() layout: [y, x].
    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level: enforce the implementation-level limits on effective
    // kernel extent, padding and stride.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
               pad_left, pad_right);

    // Zero-pad spatial dims only; batch (dim 0) and channel (dim 3) are untouched.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // INT8 operands carry zero-points; subtract them before convolving
        // (padding afterwards is then correct, since padded values are 0).
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    // Reshape bias [b_out_channels] -> [1,1,1,b_out_channels] and broadcast to
    // the full output shape; a size-1 bias is also broadcast over channels.
    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = (b_out_channels == 1) ? out_channels : 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    //    Input channel ic with multiplier index cm maps to output channel
    //    ic * f_multiplier + cm.
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType.
                                // Patch dim 3 is indexed as ow * out_height + oh, i.e. the patch
                                // enumeration runs height-fastest.
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh,
                                                                                       ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate the 48-bit accumulator result into its legal range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1243
Tai Lya4d748b2023-03-28 22:06:56 +00001244template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001245OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
Tai Lya4d748b2023-03-28 22:06:56 +00001246 TosaAttributeBase* attribute_,
1247 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001248 : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001249{
1250 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001251 setRequiredRank(2, 2);
Eric Kunzee5e26762020-10-13 16:11:07 -07001252
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001253 INIT_ATTRIBUTE(FullyConnected);
Eric Kunzee5e26762020-10-13 16:11:07 -07001254}
1255
Tai Lya4d748b2023-03-28 22:06:56 +00001256template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001257OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001258{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001259 if (attribute)
1260 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001261}
1262
Tai Lya4d748b2023-03-28 22:06:56 +00001263template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001264int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001265{
1266 if (validateRequiredOperands())
1267 return 1;
1268
1269 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1270 {
1271 return 1;
1272 }
1273
1274 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1275 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1276 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
1277
1278 if (input->getShape()[1] != weight->getShape()[1])
1279 {
1280 printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
1281 return 1;
1282 }
1283
1284 if (weight->getShape()[0] != bias->getShape()[0])
1285 {
1286 printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
1287 return 1;
1288 }
1289
James Wardd34b3fc2023-01-18 14:51:25 +00001290 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001291 "OpFullyConnected: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001292
James Ward8b390432022-08-12 20:48:56 +01001293 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001294
Tai Lya4d748b2023-03-28 22:06:56 +00001295 ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
1296 "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
1297 ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
1298 "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001299
Eric Kunzee5e26762020-10-13 16:11:07 -07001300 return 0;
1301}
1302
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    // Compute output = input x weight^T + bias, accumulating in AccEigenType.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    // Contract dimension 1 of input with dimension 0 of the (transposed) weight.
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    // Transpose weight from [OC, IC] to [IC, OC] so the contraction lines up.
    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    int b_out_channels = this->bias->getShape()[0];
    int out_channels = this->output->getShape()[1];

    // Bias is either per-channel ([OC]) or a single broadcast value ([1]).
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d",
             b_out_channels, out_channels);

    // Reshape bias to a [1, b_out_channels] row ...
    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = b_out_channels;

    // ... and broadcast it over the batch (and over channels when bias is [1]).
    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1;

    TIn input_val = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // Remove int8 zero-points before the multiply-accumulate.
        input_val = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_val = input_val.abs();
        weight_val = weight_val.abs();
        bias_val = bias_val.abs();
    }

    this->output->getTensor() = input_val.template cast<AccEigenType>()
                                    .contract(weight_val.template cast<AccEigenType>(), dims)
                                    .template cast<OutEigenType>() +
                                bias_val.reshape(bias_reshape).broadcast(bias_bcast);

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Clamp INT48 accumulations to the representable 48-bit range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}
1355
Tai Lya4d748b2023-03-28 22:06:56 +00001356template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001357OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001358 : GraphNode(sgt_, Op_MATMUL, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001359{
1360 setRequiredOperands(2, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001361 setRequiredRank(3, 3);
Eric Kunzee5e26762020-10-13 16:11:07 -07001362
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001363 INIT_ATTRIBUTE(MatMul);
Eric Kunzee5e26762020-10-13 16:11:07 -07001364}
1365
Tai Lya4d748b2023-03-28 22:06:56 +00001366template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001367OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001368{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001369 if (attribute)
1370 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001371}
1372
Tai Lya4d748b2023-03-28 22:06:56 +00001373template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001374int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001375{
1376 if (validateRequiredOperands())
1377 return 1;
1378
1379 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1380 {
1381 return 1;
1382 }
1383
James Wardd34b3fc2023-01-18 14:51:25 +00001384 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001385 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001386
Kevin Cheng2d60f002021-06-09 14:18:32 -07001387 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1388 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001389 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001390
Kevin Cheng2d60f002021-06-09 14:18:32 -07001391 ASSERT_MEM(a && b && output);
1392
1393 // a: [N, H, C]
1394 // b: [N, C, W]
1395 // c: [N, H, W]
1396
1397 // Check N
1398 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001399 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001400 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001401 return 1;
1402 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001403 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001404
Kevin Cheng2d60f002021-06-09 14:18:32 -07001405 // Check C
1406 if (a->getShape()[2] != b->getShape()[1])
1407 {
1408 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1409 return 1;
1410 }
1411 C = a->getShape()[2];
1412
1413 // Check H
1414 if (a->getShape()[1] != output->getShape()[1])
1415 {
1416 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1417 return 1;
1418 }
1419 H = a->getShape()[1];
1420
1421 // Check W
1422 if (b->getShape()[2] != output->getShape()[2])
1423 {
1424 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1425 return 1;
1426 }
1427 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001428
Tai Lya4d748b2023-03-28 22:06:56 +00001429 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
1430 "OpMatMul: A zeropoint must be zero for non int8_t data");
1431 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
1432 "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001433
Eric Kunzee5e26762020-10-13 16:11:07 -07001434 return 0;
1435}
1436
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    // Batched matrix multiply: for each batch n, output[n] = a[n] x b[n],
    // where a[n] is [H, C] and b[n] is [C, W], accumulating in AccEigenType.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    // Contract dim 1 of a (C) with dim 0 of b (C).
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        // Remove int8 zero-points before accumulation.
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of matmul operands
        a_val = a_val.abs();
        b_val = b_val.abs();
    }

    // Per-batch rank-2 views and the rank-3 shape used to reassemble results.
    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        // Slice out batch i, drop the unit batch dim, contract, then restore it.
        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            // Append this batch's result along the batch dimension.
            TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Clamp INT48 accumulations to the representable 48-bit range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1498
Tai Lya4d748b2023-03-28 22:06:56 +00001499template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001500OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001501 : GraphNode(sgt_, Op_MAX_POOL2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001502{
1503 setRequiredOperands(1, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001504 setRequiredRank(4, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -07001505
Kevin Cheng93a16282021-08-31 16:14:03 -07001506 INIT_ATTRIBUTE(Pool);
Eric Kunzee5e26762020-10-13 16:11:07 -07001507}
1508
Tai Lya4d748b2023-03-28 22:06:56 +00001509template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001510OpMaxPool2d<Dtype>::~OpMaxPool2d()
1511{
1512 if (attribute)
1513 delete attribute;
1514}
1515
Tai Lya4d748b2023-03-28 22:06:56 +00001516template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001517int OpMaxPool2d<Dtype>::checkTensorAttributes()
1518{
1519 if (validateRequiredOperands())
1520 return 1;
1521
1522 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
1523 {
1524 return 1;
1525 }
1526
1527 if (inputs[0]->matchType(*outputs[0]))
1528 {
1529 printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
1530 return 1;
1531 }
1532
1533 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1534 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1535
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001536 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +00001537 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001538 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001539 msg = "OpMaxPool2d: " + msg;
1540 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001541 return 1;
1542 }
1543
1544 return 0;
1545}
1546
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    // Max pooling via im2col: pad the input with the dtype minimum, extract
    // every KHxKW patch, then take the max of each patch via argmax.
    int in_batch = this->in->getShape()[0];
    int in_height = this->in->getShape()[1];
    int in_width = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch = this->out->getShape()[0];
    int out_height = this->out->getShape()[1];
    int out_width = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // pad() layout is [top, bottom, left, right].
    int pad_top = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left = this->attribute->pad()[2];
    int pad_right = this->attribute->pad()[3];

    // kernel()/stride() layout is [y, x].
    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    // im2col shape: one column per output element, one row per patch element.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Spatial padding only (batch and channel dims are not padded).
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KHxHW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1639
Tai Lya4d748b2023-03-28 22:06:56 +00001640template <TOSA_REF_TYPE Dtype>
1641OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Luke Hutton57287132023-02-06 14:54:18 +00001642 : GraphNode(sgt_, Op_FFT2D, id_)
1643{
1644 setRequiredOperands(2, 2);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001645 setRequiredRank(3, 3);
Luke Hutton57287132023-02-06 14:54:18 +00001646
1647 INIT_ATTRIBUTE(FFT);
1648}
1649
Tai Lya4d748b2023-03-28 22:06:56 +00001650template <TOSA_REF_TYPE Dtype>
1651OpFFT2d<Dtype>::~OpFFT2d()
1652{
Luke Hutton57287132023-02-06 14:54:18 +00001653 if (attribute)
1654 delete attribute;
1655}
1656
Tai Lya4d748b2023-03-28 22:06:56 +00001657template <TOSA_REF_TYPE Dtype>
Luke Hutton57287132023-02-06 14:54:18 +00001658int OpFFT2d<Dtype>::checkTensorAttributes()
1659{
1660 if (validateRequiredOperands())
1661 return 1;
1662
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001663 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]) ||
1664 validateRequiredRank(outputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001665 {
1666 return 1;
1667 }
1668
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001669 if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || inputs[0]->matchType(*inputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001670 {
1671 printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
1672 return 1;
1673 }
1674
1675 in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1676 in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
1677 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1678 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1679
1680 ASSERT_MEM(in_real && in_imag && out_real && out_imag);
1681
1682 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001683 if (check_fft_shape(in_real->getShape(), in_imag->getShape(), out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton57287132023-02-06 14:54:18 +00001684 {
1685 msg = "OpFFT2d: " + msg;
1686 printNodeValidationError(msg.c_str());
1687 return 1;
1688 }
1689
1690 return 0;
1691}
1692
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
    // Direct (non-fast) 2D DFT over a complex input, one output point at a
    // time. The inverse() attribute flips the sign of the twiddle angle.
    int in_real_batch = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width = this->in_real->getShape()[2];

    int in_imag_batch = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width = this->in_imag->getShape()[2];

    int out_real_batch = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width = this->out_real->getShape()[2];

    int out_imag_batch = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP, "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    // Inverse transform negates the exponent's angle.
    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

    TIn in_real_val = this->in_real->getTensor();
    TIn in_imag_val = this->in_imag->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of real and imag operands
        in_real_val = in_real_val.abs();
        in_imag_val = in_imag_val.abs();
    }

    // O(H*W * H*W) per batch: every output bin sums over every input sample.
    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = in_real_val(n, iy, ix);
                        OutEigenType val_imag = in_imag_val(n, iy, ix);
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI *
                            ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
                        // Accumulate (val_real + j*val_imag) * e^{-j*a}.
                        sum_real += val_real * cos(a) + val_imag * sin(a);
                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1767
Tai Lya4d748b2023-03-28 22:06:56 +00001768template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001769OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Luke Hutton261b7b62023-01-10 14:50:31 +00001770 : GraphNode(sgt_, Op_RFFT2D, id_)
1771{
1772 setRequiredOperands(1, 2);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001773 setRequiredRank(3, 3);
Luke Hutton261b7b62023-01-10 14:50:31 +00001774}
1775
Tai Lya4d748b2023-03-28 22:06:56 +00001776template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001777OpRFFT2d<Dtype>::~OpRFFT2d()
1778{}
Luke Hutton261b7b62023-01-10 14:50:31 +00001779
Tai Lya4d748b2023-03-28 22:06:56 +00001780template <TOSA_REF_TYPE Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +00001781int OpRFFT2d<Dtype>::checkTensorAttributes()
1782{
1783 if (validateRequiredOperands())
1784 return 1;
1785
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001786 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
Luke Hutton261b7b62023-01-10 14:50:31 +00001787 {
1788 return 1;
1789 }
1790
1791 if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
1792 {
1793 printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
1794 return 1;
1795 }
1796
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001797 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
Luke Hutton261b7b62023-01-10 14:50:31 +00001798 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1799 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1800
1801 ASSERT_MEM(in && out_real && out_imag);
1802
Luke Hutton57287132023-02-06 14:54:18 +00001803 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001804 if (check_fft_shape(in->getShape(), {}, out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton261b7b62023-01-10 14:50:31 +00001805 {
Luke Hutton57287132023-02-06 14:54:18 +00001806 msg = "OpRFFT2d: " + msg;
1807 printNodeValidationError(msg.c_str());
Luke Hutton261b7b62023-01-10 14:50:31 +00001808 return 1;
1809 }
1810
1811 return 0;
1812}
1813
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
    // Direct (non-fast) 2D DFT of a real-valued input. Only the output
    // shapes' worth of bins is produced (out_real_width covers the
    // non-redundant half of the spectrum per check_fft_shape).
    int32_t in_batch = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width = in->getShape()[2];

    int32_t out_real_batch = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width = out_real->getShape()[2];

    int32_t out_imag_batch = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width = out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width, out_real_batch, out_real_height, out_real_width, out_imag_batch,
               out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

    TIn in_val = this->in->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in operand
        in_val = in_val.abs();
    }

    // Each output bin is a full sum over the input samples.
    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        // Accumulate in_val * e^{-j*a} (real input, so no imag term).
                        sum_real += in_val(n, iy, ix) * cos(a);
                        sum_imag += -in_val(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1876
Tai Lya4d748b2023-03-28 22:06:56 +00001877template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001878OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
Tai Lya4d748b2023-03-28 22:06:56 +00001879 TosaAttributeBase* attribute_,
1880 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001881 : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001882{
1883 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001884 setRequiredRank(4, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -07001885
Kevin Cheng93a16282021-08-31 16:14:03 -07001886 INIT_ATTRIBUTE(TransposeConv);
Eric Kunzee5e26762020-10-13 16:11:07 -07001887}
1888
Tai Lya4d748b2023-03-28 22:06:56 +00001889template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001890OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001891{
1892 if (attribute)
1893 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001894}
1895
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    // Validate operand/result counts, ranks, dtypes, attribute sizes and the
    // transpose-conv output-shape arithmetic. Returns 0 on success, 1 on
    // validation failure. Checks run in order, so the first failure reported
    // wins.
    if (validateRequiredOperands())
        return 1;

    // NOTE(review): bias (inputs[2]) rank is not validated here — presumably
    // covered elsewhere; confirm.
    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpTransposeConv2d: Output data type not supported for this configuration of operator");

    input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // out_pad layout: [top, bottom, left, right].
    if (attribute->out_pad().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
        return 1;
    }

    // stride layout: [y, x].
    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->output_shape().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
        return 1;
    }

    // Strides must be at least 1 in both dimensions.
    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
            return 1;
        }
    }

    // The declared output_shape attribute must match the actual output tensor.
    for (int d = 0; d < 4; d++)
    {
        if (attribute->output_shape()[d] != this->output->getShape()[d])
        {
            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
            return 1;
        }
    }

    int32_t IH = input->getShape()[1];
    int32_t IW = input->getShape()[2];
    int32_t OH = output->getShape()[1];
    int32_t OW = output->getShape()[2];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_h = weight->getShape()[1];
    int32_t kernel_w = weight->getShape()[2];

    int32_t out_pad_top = attribute->out_pad()[0];
    int32_t out_pad_bottom = attribute->out_pad()[1];
    int32_t out_pad_left = attribute->out_pad()[2];
    int32_t out_pad_right = attribute->out_pad()[3];

    // Negative out_pad is allowed but must not exceed the kernel dimension
    // it pads (i / 2 + 1 maps [top,bottom]->kernel_h, [left,right]->kernel_w).
    for (size_t i = 0; i < attribute->out_pad().size(); i++)
    {
        ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
                 "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
    }

    // Expected output size for a transposed convolution.
    int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
    int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;

    if ((OH != H) || (OW != W))
    {
        std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
                          std::to_string(H) + "," + std::to_string(W) + ")";
        printNodeValidationError(msg.c_str());
        return 1;
    }

    // Zero points are only meaningful for int8 operands.
    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
             "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
1990
Tai Lya4d748b2023-03-28 22:06:56 +00001991template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001992int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -07001993{
1994 int in_batch = this->input->getShape()[0];
1995 int in_height = this->input->getShape()[1];
1996 int in_width = this->input->getShape()[2];
1997 int in_channels = this->input->getShape()[3];
1998
1999 int f_out_channels = this->weight->getShape()[0];
2000 int f_height = this->weight->getShape()[1];
2001 int f_width = this->weight->getShape()[2];
2002 int f_in_channels = this->weight->getShape()[3];
2003
2004 int b_out_channels = this->bias->getShape()[0];
2005
2006 int out_batch = this->output->getShape()[0];
2007 int out_height = this->output->getShape()[1];
2008 int out_width = this->output->getShape()[2];
2009 int out_channels = this->output->getShape()[3];
2010
TatWai Chong24594f52022-06-08 00:48:04 -07002011 int out_pad_top = this->attribute->out_pad()[0];
2012 int out_pad_bottom = this->attribute->out_pad()[1];
2013 int out_pad_left = this->attribute->out_pad()[2];
2014 int out_pad_right = this->attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002015
Jerry Gea793f462023-04-11 00:05:02 +00002016 int stride_y = this->attribute->stride()[0];
2017 int stride_x = this->attribute->stride()[1];
Eric Kunzee5e26762020-10-13 16:11:07 -07002018
Kevin Chengacb550f2021-06-29 15:32:19 -07002019 ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
2020 ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
2021 in_channels);
2022 ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
2023 f_out_channels, out_channels);
Tai Lya641dd52023-08-11 19:58:50 +00002024 ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
2025 "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -07002026
Jerry Gea793f462023-04-11 00:05:02 +00002027 // Check Tosa Level
2028 auto tosa_level = g_func_config.tosa_level;
2029 LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
2030 LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
2031 LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002032 LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL,
2033 "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +00002034 LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
2035 LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
2036 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
2037 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
2038
Eric Kunzee5e26762020-10-13 16:11:07 -07002039 DEBUG_INFO(OP,
2040 "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +00002041 "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002042 in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
2043 out_height, out_width, out_channels, stride_y, stride_x, out_pad_top, out_pad_bottom, out_pad_left,
2044 out_pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -07002045
2046 TIn input_val = this->input->getTensor();
2047 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +00002048 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002049 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002050 input_val = input_val - (InEigenType)attribute->input_zp();
2051 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -07002052 }
2053
Tai Ly307392a2023-05-12 21:42:19 +00002054 TBias bias_val = this->bias->getTensor();
2055
2056 if (g_func_config.abs_mode)
2057 {
2058 // in abs_mode: take abs values of conv operands
2059 input_val = input_val.abs();
2060 weight_val = weight_val.abs();
2061 bias_val = bias_val.abs();
2062 }
2063
Eric Kunzee5e26762020-10-13 16:11:07 -07002064 Eigen::array<Eigen::Index, 4> reshape_dim;
2065 reshape_dim.fill(1);
2066 reshape_dim[3] = b_out_channels;
2067
2068 Eigen::array<Eigen::Index, 4> bcast;
2069 bcast[0] = out_batch;
2070 bcast[1] = out_height;
2071 bcast[2] = out_width;
Tai Lya641dd52023-08-11 19:58:50 +00002072 bcast[3] = (b_out_channels == 1) ? out_channels : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -07002073
2074 // initialize with bias
Tai Ly307392a2023-05-12 21:42:19 +00002075 this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
Eric Kunzee5e26762020-10-13 16:11:07 -07002076
2077 int out_x_origin, out_y_origin;
2078 int out_x, out_y;
2079
2080 // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
2081 for (int ob = 0; ob < out_batch; ob++)
2082 {
2083 for (int ih = 0; ih < in_height; ih++)
2084 {
2085 for (int iw = 0; iw < in_width; iw++)
2086 {
Jerry Gea793f462023-04-11 00:05:02 +00002087 out_x_origin = iw * stride_x + out_pad_left;
2088 out_y_origin = ih * stride_y + out_pad_top;
Eric Kunzee5e26762020-10-13 16:11:07 -07002089 for (int ic = 0; ic < in_channels; ic++)
2090 {
2091 for (int fh = 0; fh < f_height; fh++)
2092 {
2093 for (int fw = 0; fw < f_width; fw++)
2094 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002095 out_x = out_x_origin + fw;
2096 out_y = out_y_origin + fh;
Eric Kunzee5e26762020-10-13 16:11:07 -07002097 for (int oc = 0; oc < out_channels; oc++)
2098 {
2099 if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
2100 {
2101 this->output->getTensor()(ob, out_y, out_x, oc) +=
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002102 (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
2103 (AccEigenType)weight_val(oc, fh, fw, ic));
Eric Kunzee5e26762020-10-13 16:11:07 -07002104 }
2105 }
2106 }
2107 }
2108 }
2109 }
2110 }
2111 }
2112
Tai Lya4d748b2023-03-28 22:06:56 +00002113 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -07002114 {
James Ward8b390432022-08-12 20:48:56 +01002115 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
2116 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -07002117 }
2118
2119 return GraphNode::eval();
2120}
2121
// template explicit instantiation
// Each DEF_INSTANTIATE_* macro emits the explicit template instantiations for
// one operator across its supported TOSA data-type combinations.

// [in_t]
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);

// [in_t, acc_t]
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);

// [in_t]
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);

// [in_t, out_t]
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);

// [in_t]
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);

// [in_t]
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);