blob: 96fd41c21a1c90d9a7094ff015a359e8ea05a239 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Luke Hutton261b7b62023-01-10 14:50:31 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "tensor_ops.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000017#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070018#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
Kevin Cheng9fe17242021-11-10 01:04:39 +000025int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
26 std::vector<int32_t> input_shape,
27 std::vector<int32_t> output_shape,
28 std::string& msg)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000029{
TatWai Chong86c403b2022-06-06 20:46:01 -070030 if (attribute->pad().size() != 4)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000031 {
32 msg = "illegal size for attribute padding";
33 return 1;
34 }
35
36 if (attribute->kernel().size() != 2)
37 {
38 msg = "illegal size for attribute kernel";
39 return 1;
40 }
41
42 if (attribute->stride().size() != 2)
43 {
44 msg = "illegal size for attribute stride";
45 return 1;
46 }
47
TatWai Chong86c403b2022-06-06 20:46:01 -070048 for (int32_t i : attribute->pad())
Kevin Cheng7eb93d72021-10-09 01:26:08 +000049 {
50 if (i < 0)
51 {
52 msg = "At least one pad is smaller than zero";
53 return 1;
54 }
55 }
56
57 for (int32_t i : attribute->kernel())
58 {
59 if (i < 1)
60 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000061 msg = "At least one kernel dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000062 return 1;
63 }
64 }
65
66 for (int32_t i : attribute->stride())
67 {
68 if (i < 1)
69 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000070 msg = "At least one stride dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000071 return 1;
72 }
73 }
74
75 int32_t IH = input_shape[1];
76 int32_t IW = input_shape[2];
77 int32_t OH = output_shape[1];
78 int32_t OW = output_shape[2];
79
TatWai Chong86c403b2022-06-06 20:46:01 -070080 int32_t pad_top = attribute->pad()[0];
81 int32_t pad_bottom = attribute->pad()[1];
82 int32_t pad_left = attribute->pad()[2];
83 int32_t pad_right = attribute->pad()[3];
Kevin Cheng7eb93d72021-10-09 01:26:08 +000084
85 int32_t stride_y = attribute->stride()[0];
86 int32_t stride_x = attribute->stride()[1];
87 int32_t kernel_y = attribute->kernel()[0];
88 int32_t kernel_x = attribute->kernel()[1];
89
90 if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
91 {
92 msg = "At least one pad is >= kernel dimension";
93 return 1;
94 }
95
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010096 int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
97 int32_t full_W = IW + pad_left + pad_right - kernel_x;
98
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 if ((full_H % stride_y != 0) || (full_W % stride_x != 0))
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100101 msg = "Parameters must yield exact integer output dimensions";
102 return 1;
103 }
104
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000105 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100106 {
107 msg = "Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000108 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000109 return 1;
110 }
111
112 return 0;
113}
114
// Validate a convolution attribute (pad/stride/dilation and zero points)
// against the input/output shapes and the weight tensor shape.
// conv_dimension is 2 (CONV2D/DEPTHWISE) or 3 (CONV3D); offset_kernel is the
// index of the first spatial kernel dimension inside `weights`.
// Returns 0 on success; on failure returns 1 and fills msg with a description.
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         TOSA_REF_TYPE InDtype,
                         TOSA_REF_TYPE WeightDtype,
                         std::string& msg)
{
    // pad holds two entries (before/after) per spatial dimension.
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    // Pads must be non-negative; stride and dilation must be at least 1.
    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    // offset_d shifts the H/W indices by one when a depth (D) dimension is
    // present (3D convolution); for 2D it is 0 and the D values default to 1/0.
    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH = input_shape[1 + offset_d];
    int32_t IW = input_shape[2 + offset_d];
    int32_t OD = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH = output_shape[1 + offset_d];
    int32_t OW = output_shape[2 + offset_d];

    int32_t stride_d = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y = attribute->stride()[0 + offset_d];
    int32_t stride_x = attribute->stride()[1 + offset_d];
    int32_t kernel_d = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h = weights[offset_kernel + offset_d];
    int32_t kernel_w = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    // pad entries come in before/after pairs, so the index shift for the pad
    // vector is twice the dimension shift used above.
    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

    // Effective span covered by the (dilated) kernel; must divide evenly by
    // the stride so the output dimensions are exact integers.
    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    if ((full_H % stride_y != 0) || (full_W % stride_x != 0) || (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1) || (OD != (full_D / stride_d) + 1))
    {
        // Only report the depth dimension in the message for 3D convolutions.
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" + msg_d +
              std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    // Zero points are only meaningful for int8 tensors; all other data types
    // must use zero.
    if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
    {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
    {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}
233
// Validate the shapes supplied to an FFT2D/RFFT2D operator. An empty in_imag
// signals RFFT2D (real-only input, single input tensor). Returns 0 on
// success; on failure returns 1 and fills msg with a description.
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg)
{
    const bool is_rfft       = in_imag.empty();
    auto is_power_of_two = [](int32_t n) -> bool { return n > 0 && (n & (n - 1)) == 0; };

    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // RFFT does not have a second input
    if (!is_rfft)
    {
        for (size_t i = 0; i < in_real.size(); i++)
        {
            if (in_real[i] != in_imag[i])
            {
                msg = "Mismatch between real input shape and imaginary input shape";
                return 1;
            }
        }
    }

    // Real and imaginary outputs must always agree element-wise.
    for (size_t i = 0; i < out_real.size(); i++)
    {
        if (out_real[i] != out_imag[i])
        {
            msg = "Mismatch between real output shape and imaginary output shape";
            return 1;
        }
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (is_rfft)
    {
        // RFFT output keeps only the non-redundant half of the spectrum.
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    }
    else
    {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}
313
template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    // ARGMAX takes exactly one input tensor and produces one output tensor.
    setRequiredOperands(1, 1);
    // Minimum input rank of 1 — there must be an axis to reduce over.
    setRequiredRank(1);

    // Parse the Axis attribute; the instance is owned and freed in ~OpArgMax().
    INIT_ATTRIBUTE(Axis);
}
323
Tai Lya4d748b2023-03-28 22:06:56 +0000324template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700325OpArgMax<Rank, Dtype>::~OpArgMax()
326{
327 if (attribute)
328 delete attribute;
329}
330
template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
    // Verify operand count and input rank constraints set in the constructor.
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]))
    {
        return 1;
    }

    // ARGMAX removes the reduced axis, so output rank is input rank - 1.
    int32_t output_rank = inputs[0]->getRank() - 1;
    if (output_rank != outputs[0]->getRank())
    {
        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
        return 1;
    }

    // Output indices are always INT32 in this reference implementation.
    if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
    {
        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // The reduction axis must name a valid input dimension.
    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
    {
        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
        return 1;
    }

    // The output shape must equal the input shape with the axis dimension
    // removed: dims before the axis match 1:1, dims after shift down by one.
    bool shape_check = true;
    for (int32_t i = 0; i < input->getRank(); i++)
    {
        if (i < attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i])
            {
                shape_check = false;
                break;
            }
        }
        else if (i > attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i - 1])
            {
                shape_check = false;
                break;
            }
        }
        // No need to check i == axis
    }
    if (!shape_check)
    {
        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
        return 1;
    }

    return 0;
}
393
template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    // Eigen's argmax reduces the chosen axis, yielding a rank-(Rank-1) tensor
    // of DenseIndex positions of the maximum along that axis.
    Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());

    // Cast the index values to the output element type (INT32).
    this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });

    return GraphNode::eval();
}
407
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    // AVG_POOL2D takes one input tensor and produces one output tensor.
    setRequiredOperands(1, 1);
    // Input and output must both be rank 4 (NHWC).
    setRequiredRank(4, 4);

    // Parse the Pool attribute (pad/kernel/stride/accum_dtype/zero points);
    // the instance is owned and freed in ~OpAvgPool2d().
    INIT_ATTRIBUTE(Pool);
}
417
Tai Lya4d748b2023-03-28 22:06:56 +0000418template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100419OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700420{
421 if (attribute)
422 delete attribute;
423}
424
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
    // Verify operand count and the rank-4 requirement set in the constructor.
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // Input and output element types must match (matchType returns non-zero
    // on mismatch).
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Zero points are only meaningful for int8 data; everything else must
    // use zero.
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
             "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");

    // Shared pool attribute validation: arity, value ranges, and exact
    // output-shape agreement.
    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
460
Eric Kunze830add42022-01-25 22:56:46 -0800461// This calculates the number of padding elements used for each location along an axis
462// Average pooling only divides by the number of elements used, not including padding.
463// This function uses left/right, but is also used for vertical padding with top/bottom
Tai Lya4d748b2023-03-28 22:06:56 +0000464template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
465ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
466 int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
Eric Kunzee5e26762020-10-13 16:11:07 -0700467{
468 ETensor1<int32_t> result(out_size);
469
Eric Kunzee5e26762020-10-13 16:11:07 -0700470 result.setConstant(kernel_size);
471
Eric Kunze830add42022-01-25 22:56:46 -0800472 // adjust divisors on the left side for padding
473 // We start at the leftmost output element, and remove pad_left - (index * stride) elements
474 // until we have no more padding being used
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000475 for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++)
476 {
Eric Kunze830add42022-01-25 22:56:46 -0800477 int32_t adjust = pad_left - (index * stride);
478 result(index) -= adjust;
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 }
480
Eric Kunze830add42022-01-25 22:56:46 -0800481 // The process repeats on the right side. Padding starts taking effect as we
482 // near the rightmost input element. The first output element which touches
483 // padding is defined in the initialization of index below. Then we keep moving
484 // to the right, increasing padding until we get to the last output element.
485 int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000486 for (; index < out_size; index++)
487 {
Eric Kunze830add42022-01-25 22:56:46 -0800488 int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
489 result(index) -= adjust;
Eric Kunzee5e26762020-10-13 16:11:07 -0700490 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700491 return result;
492}
493
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
    // NHWC dimensions of input and output.
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // pad is [top, bottom, left, right]; kernel/stride are [y, x].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];
    int kernel_y   = this->attribute->kernel()[0];
    int kernel_x   = this->attribute->kernel()[1];
    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    // im2col layout: one column per output element, one row per kernel tap.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Zero-pad only the spatial (H, W) dimensions.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        // Remove the input zero point before accumulation.
        input_val = input_val - (InEigenType)attribute->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of input_padded
        input_padded = input_padded.abs();
    }

    // assuming input and output have same scales
    // so input and output scaling is not required
    // TODO: check if this assumption TOSA made

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_y * kernel_x; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide with div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_y, stride_y, pad_top, pad_bottom);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_x, stride_x, pad_left, pad_right);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    ETensor2<int32_t> dm2_w = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
    ETensor2<int32_t> dm2_h = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
    ETensor4<int32_t> div_map = dm2_h.contract(dm2_w, contract_dims)
                                    .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
                                    .broadcast(bcast);
    if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
        Dtype != TOSA_REF_TYPE_FP64)
    {
        // Integer path: divide via the TOSA reciprocal-scale fixed-point
        // multiply, then re-apply the output zero point and clamp to the
        // output type's representable range.
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        // Case for float-types
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}
638
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    // CONV2D takes three inputs (input, weight, bias) and one output.
    setRequiredOperands(3, 1);
    // Input and output must both be rank 4 (NHWC).
    setRequiredRank(4, 4);

    // Parse the Conv attribute (pad/stride/dilation/zero points);
    // the instance is owned and freed in ~OpConv2d().
    INIT_ATTRIBUTE(Conv);
}
648
Tai Lya4d748b2023-03-28 22:06:56 +0000649template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000650OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700651{
652 if (attribute)
653 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700654}
655
Tai Lya4d748b2023-03-28 22:06:56 +0000656template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000657int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700658{
659 if (validateRequiredOperands())
660 return 1;
661
662 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
663 {
664 return 1;
665 }
666
667 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
668 if (inputs[2]->getRank() != 1)
669 {
670 printNodeValidationError("OpConv2d: bias tensor must be rank 1");
671 }
672
James Wardd34b3fc2023-01-18 14:51:25 +0000673 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000674 "OpConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700675
Eric Kunzee5e26762020-10-13 16:11:07 -0700676 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
677 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
678 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100679 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700680
Kevin Cheng9fe17242021-11-10 01:04:39 +0000681 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000682 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000683 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000685 msg = "OpConv2d: " + msg;
686 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700687 return 1;
688 }
689
Eric Kunzee5e26762020-10-13 16:11:07 -0700690 return 0;
691}
692
Tai Lya4d748b2023-03-28 22:06:56 +0000693template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000694int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700695{
696 int in_batch = this->input->getShape()[0];
697 int in_height = this->input->getShape()[1];
698 int in_width = this->input->getShape()[2];
699 int in_channels = this->input->getShape()[3];
700
701 int f_out_channels = this->weight->getShape()[0];
702 int f_height = this->weight->getShape()[1];
703 int f_width = this->weight->getShape()[2];
704 int f_in_channels = this->weight->getShape()[3];
705
706 int b_out_channels = this->bias->getShape()[0];
707
708 int out_batch = this->output->getShape()[0];
709 int out_height = this->output->getShape()[1];
710 int out_width = this->output->getShape()[2];
711 int out_channels = this->output->getShape()[3];
712
Kevin Chengacb550f2021-06-29 15:32:19 -0700713 ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
714 ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
715 in_channels);
716 ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
717 out_channels);
718 ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700719
TatWai Chong86c403b2022-06-06 20:46:01 -0700720 int pad_top = this->attribute->pad()[0];
721 int pad_bottom = this->attribute->pad()[1];
722 int pad_left = this->attribute->pad()[2];
723 int pad_right = this->attribute->pad()[3];
724
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000725 int stride_y = this->attribute->stride()[0];
726 int stride_x = this->attribute->stride()[1];
727 int dilation_y = this->attribute->dilation()[0];
728 int dilation_x = this->attribute->dilation()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000729
730 // Check Tosa Level
731 auto tosa_level = g_func_config.tosa_level;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000732 LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
733 "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
734 LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
735 "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +0000736 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
737 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
738 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
739 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
740 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
741 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
Eric Kunzee5e26762020-10-13 16:11:07 -0700742
743 DEBUG_INFO(OP,
744 "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +0000745 "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
Eric Kunzee5e26762020-10-13 16:11:07 -0700746 in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000747 out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
748 pad_left, pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -0700749
750 // GEMM-conv2d, left matrix is input, right matrix is weight
751 Eigen::array<Eigen::Index, 2> im2col_input_dims;
752 im2col_input_dims[0] = out_batch * out_height * out_width;
753 im2col_input_dims[1] = f_height * f_width * f_in_channels;
754
755 Eigen::array<Eigen::Index, 2> im2col_weight_dims;
756 im2col_weight_dims[0] = f_height * f_width * f_in_channels;
757 im2col_weight_dims[1] = f_out_channels;
758
759 Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
760 bias_reshaped_dims[0] = 1;
761 bias_reshaped_dims[1] = b_out_channels;
762
763 Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
764 weight_zp_bcast_dims[0] = f_height;
765 weight_zp_bcast_dims[1] = f_width;
766 weight_zp_bcast_dims[2] = f_in_channels;
767
768 Eigen::array<Eigen::Index, 2> bias_bcast_dims;
769 bias_bcast_dims[0] = out_batch * out_height * out_width;
770 bias_bcast_dims[1] = 1;
771
772 Eigen::array<Eigen::Index, 4> col2im_output_dims;
773 col2im_output_dims[0] = out_batch;
774 col2im_output_dims[1] = out_height;
775 col2im_output_dims[2] = out_width;
776 col2im_output_dims[3] = out_channels;
777
778 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
779
TatWai Chong86c403b2022-06-06 20:46:01 -0700780 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
781 pad[0] = std::make_pair(0, 0);
782 pad[1] = std::make_pair(pad_top, pad_bottom);
783 pad[2] = std::make_pair(pad_left, pad_right);
784 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700785
786 TIn input_val = this->input->getTensor();
787 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000788 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700789 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000790 input_val = input_val - (InEigenType)attribute->input_zp();
791 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700792 }
793
TatWai Chong86c403b2022-06-06 20:46:01 -0700794 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700795
Tai Ly307392a2023-05-12 21:42:19 +0000796 TBias bias_val = this->bias->getTensor();
797
798 if (g_func_config.abs_mode)
799 {
800 // in abs_mode: take abs values of conv operands
801 input_padded = input_padded.abs();
802 weight_val = weight_val.abs();
803 bias_val = bias_val.abs();
804 }
805
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 // extract_image_patches() output [N, KH, KW, H * W, C]
807 // need to transpose to [N, H * W, KH, KW, C]
808 ETensor5<InEigenType> input_extract_patches =
809 input_padded
Jerry Gea793f462023-04-11 00:05:02 +0000810 .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700811 .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
812
813 // reshape input to [N * H * W, KH * KW * C]
814 ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
815
816 // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
817 ETensor2<WeightEigenType> im2col_weight =
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000818 weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
Eric Kunzee5e26762020-10-13 16:11:07 -0700819
820 // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
821 // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
Tai Ly307392a2023-05-12 21:42:19 +0000822 ETensor2<OutEigenType> bias_2d =
823 (bias_val.reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700824
825 // output matrix is [N * H * W, C]
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000826 ETensor2<OutEigenType> contracted_result = (im2col_input.template cast<AccEigenType>().contract(
827 im2col_weight.template cast<AccEigenType>(), contract_dims))
828 .template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700829
830 // adding bias
James Ward8b390432022-08-12 20:48:56 +0100831 ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700832
833 // reshape back to [N, H, W, C]
834 this->output->getTensor() = biased_output.reshape(col2im_output_dims);
835
Tai Lya4d748b2023-03-28 22:06:56 +0000836 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -0700837 {
James Ward8b390432022-08-12 20:48:56 +0100838 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
839 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700840 }
841
842 return GraphNode::eval();
843}
844
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    // CONV3D node: 3 operands (input, weight, bias), 1 output.
    setRequiredOperands(3, 1);
    // All rank-checked tensors must be rank 5 (presumably NDHWC - see eval()).
    setRequiredRank(5, 5);

    // Captures the Conv attribute into `attribute`; the destructor deletes it,
    // so the macro presumably allocates an owned copy - confirm in its definition.
    INIT_ATTRIBUTE(Conv);
}
854
Tai Lya4d748b2023-03-28 22:06:56 +0000855template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000856OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
Kevin Cheng1533b852021-09-01 12:51:58 -0700857{
858 if (attribute)
859 delete attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700860}
861
Tai Lya4d748b2023-03-28 22:06:56 +0000862template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000863int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Kevin Cheng1533b852021-09-01 12:51:58 -0700864{
865 if (validateRequiredOperands())
866 return 1;
867
868 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
869 {
870 return 1;
871 }
872
873 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
874 if (inputs[2]->getRank() != 1)
875 {
876 printNodeValidationError("OpConv3d: bias tensor must be rank 1");
877 }
878
James Wardd34b3fc2023-01-18 14:51:25 +0000879 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000880 "OpConv3d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700881
Kevin Cheng1533b852021-09-01 12:51:58 -0700882 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
883 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
884 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100885 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng1533b852021-09-01 12:51:58 -0700886
Kevin Cheng9fe17242021-11-10 01:04:39 +0000887 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000888 if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000889 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Kevin Cheng1533b852021-09-01 12:51:58 -0700890 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000891 msg = "OpConv3d: " + msg;
892 printNodeValidationError(msg.c_str());
Kevin Cheng1533b852021-09-01 12:51:58 -0700893 return 1;
894 }
895
Kevin Cheng1533b852021-09-01 12:51:58 -0700896 return 0;
897}
898
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Reference evaluation of CONV3D by direct (nested-loop) convolution.
    // Tensor layouts, as established by the indexing below:
    //   input  [N, ID, IH, IW, IC]   weight [OC, KD, KH, KW, IC]
    //   bias   [BC]                  output [N, OD, OH, OW, OC]
    // Returns 0 on success; ERROR_IF fails the node on shape mismatches.
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    // Cross-tensor consistency checks.
    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    // Attribute layout: pad = [d0, d1, top, bottom, left, right],
    // stride/dilation = [d, y, x].
    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL,
                "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    // Zero-pad the three spatial axes only (batch and channel untouched).
    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // INT8 operands carry zero points that must be subtracted before the MACs.
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // 1. initialize with bias
    // bias [BC] is reshaped to [1,1,1,1,BC] and broadcast over N/D/H/W.
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0]                  = out_batch;
    bcast[1]                  = out_depth;
    bcast[2]                  = out_height;
    bcast[3]                  = out_width;
    bcast[4]                  = 1;
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            // Input coordinates are in the padded tensor, so no
                            // bounds check is needed here.
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        // Multiply-accumulate in the wider AccEigenType.
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate the result to the INT48 accumulator range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1058
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    // DEPTHWISE_CONV2D node: 3 operands (input, weight, bias), 1 output.
    setRequiredOperands(3, 1);
    // All rank-checked tensors must be rank 4 (presumably NHWC - see eval()).
    setRequiredRank(4, 4);

    // Captures the Conv attribute into `attribute`; the destructor deletes it,
    // so the macro presumably allocates an owned copy - confirm in its definition.
    INIT_ATTRIBUTE(Conv);
}
1070
Tai Lya4d748b2023-03-28 22:06:56 +00001071template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001072OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001073{
1074 if (attribute)
1075 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001076}
1077
Tai Lya4d748b2023-03-28 22:06:56 +00001078template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001079int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001080{
1081 if (validateRequiredOperands())
1082 return 1;
1083
1084 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1085 {
1086 return 1;
1087 }
1088
1089 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
1090 if (inputs[2]->getRank() != 1)
1091 {
1092 printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
1093 }
1094
James Wardd34b3fc2023-01-18 14:51:25 +00001095 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001096 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001097
Eric Kunzee5e26762020-10-13 16:11:07 -07001098 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1099 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1100 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001101 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001102
Kevin Cheng9fe17242021-11-10 01:04:39 +00001103 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001104 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001105 weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001106 {
Kevin Cheng9fe17242021-11-10 01:04:39 +00001107 msg = "OpDepthwiseConv2d: " + msg;
1108 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001109 return 1;
1110 }
1111
Eric Kunzee5e26762020-10-13 16:11:07 -07001112 return 0;
1113}
1114
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Reference evaluation of DEPTHWISE_CONV2D.
    // Tensor layouts, as established by the indexing below:
    //   input  [N, IH, IW, IC]   weight [KH, KW, IC, M]  (M = channel multiplier)
    //   bias   [BC]              output [N, OH, OW, IC * M]
    // Returns 0 on success; ERROR_IF fails the node on shape mismatches.
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // Cross-tensor consistency checks.
    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    // Attribute layout: pad = [top, bottom, left, right], stride/dilation = [y, x].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
               pad_left, pad_right);

    // Zero-pad the spatial axes only (batch and channel untouched).
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // INT8 operands carry zero points that must be subtracted before the MACs.
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    // bias [BC] reshaped to [1,1,1,BC] and broadcast over N/H/W.
    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    // Each input channel ic maps to f_multiplier output channels.
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType
                                // Patch index is ow * out_height + oh: patches are
                                // numbered column-major along the spatial axes.
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) += (OutEigenType)(
                                    (AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
                                    (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate the result to the INT48 accumulator range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1255
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                                   TosaAttributeBase* attribute_,
                                                                   uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    // FULLY_CONNECTED node: 3 operands (input, weight, bias), 1 output.
    setRequiredOperands(3, 1);
    // Rank-checked tensors must be rank 2 (matrices).
    setRequiredRank(2, 2);

    // Captures the FullyConnected attribute into `attribute`; the destructor
    // deletes it, so the macro presumably allocates an owned copy - confirm
    // in its definition.
    INIT_ATTRIBUTE(FullyConnected);
}
1267
Tai Lya4d748b2023-03-28 22:06:56 +00001268template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001269OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001270{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001271 if (attribute)
1272 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001273}
1274
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    // Validate operand/result counts, ranks, shapes, dtypes and zero points.
    // Returns 0 when the node is well-formed, non-zero otherwise.
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    // input is [N, IC], weight is [OC, IC]: the contraction axes must agree.
    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    // bias is [OC], one entry per weight output row.
    // NOTE(review): bias rank itself is not checked here, unlike the conv ops -
    // presumably guaranteed upstream; confirm.
    if (weight->getShape()[0] != bias->getShape()[0])
    {
        printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Zero points are only meaningful for INT8 operands; all other dtypes must use 0.
    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
             "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
1314
// Evaluate FULLY_CONNECTED: output[N, OC] = input[N, IC] x weight[OC, IC]^T + bias[OC].
// Accumulation is done in AccEigenType; int48 outputs are clamped to [AccQMin, AccQMax].
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    // Contract dim 1 of the input with dim 0 of the (shuffled) weight.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    // Transpose weight from [OC, IC] to [IC, OC] for the contraction.
    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    // Reshape bias [OC] -> [1, OC] so it can be broadcast over the batch.
    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = this->bias->getShape()[0];

    // Broadcast the reshaped bias N times along the batch dimension.
    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    // Subtract zero points before multiplying (only relevant for int8 data).
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_val  = input_val.abs();
        weight_val = weight_val.abs();
        bias_val   = bias_val.abs();
    }

    // Matrix multiply in the accumulator type, cast down, then add bias.
    this->output->getTensor() = input_val.template cast<AccEigenType>()
                                    .contract(weight_val.template cast<AccEigenType>(), dims)
                                    .template cast<OutEigenType>() +
                                bias_val.reshape(bias_reshape).broadcast(bias_bcast);

    // int48 accumulators are narrower than the storage type: clamp the result.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}
1361
// Construct a MATMUL node: two rank-3 inputs, one rank-3 output.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    // a: [N, H, C], b: [N, C, W] -> output: [N, H, W] (see checkTensorAttributes).
    setRequiredOperands(2, 1);
    setRequiredRank(3, 3);

    // Attribute carries a_zp/b_zp; it is deleted in the destructor.
    INIT_ATTRIBUTE(MatMul);
}
1371
Tai Lya4d748b2023-03-28 22:06:56 +00001372template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001373OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001374{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001375 if (attribute)
1376 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001377}
1378
Tai Lya4d748b2023-03-28 22:06:56 +00001379template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001380int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001381{
1382 if (validateRequiredOperands())
1383 return 1;
1384
1385 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1386 {
1387 return 1;
1388 }
1389
James Wardd34b3fc2023-01-18 14:51:25 +00001390 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001391 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001392
Kevin Cheng2d60f002021-06-09 14:18:32 -07001393 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1394 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001395 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001396
Kevin Cheng2d60f002021-06-09 14:18:32 -07001397 ASSERT_MEM(a && b && output);
1398
1399 // a: [N, H, C]
1400 // b: [N, C, W]
1401 // c: [N, H, W]
1402
1403 // Check N
1404 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001405 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001406 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001407 return 1;
1408 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001409 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001410
Kevin Cheng2d60f002021-06-09 14:18:32 -07001411 // Check C
1412 if (a->getShape()[2] != b->getShape()[1])
1413 {
1414 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1415 return 1;
1416 }
1417 C = a->getShape()[2];
1418
1419 // Check H
1420 if (a->getShape()[1] != output->getShape()[1])
1421 {
1422 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1423 return 1;
1424 }
1425 H = a->getShape()[1];
1426
1427 // Check W
1428 if (b->getShape()[2] != output->getShape()[2])
1429 {
1430 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1431 return 1;
1432 }
1433 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001434
Tai Lya4d748b2023-03-28 22:06:56 +00001435 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
1436 "OpMatMul: A zeropoint must be zero for non int8_t data");
1437 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
1438 "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001439
Eric Kunzee5e26762020-10-13 16:11:07 -07001440 return 0;
1441}
1442
// Evaluate batched MATMUL: for each batch n, output[n] = a[n] x b[n],
// accumulating in AccEigenType. int48 outputs are clamped to [AccQMin, AccQMax].
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    // Contract dim 1 of a[n] ([H, C]) with dim 0 of b[n] ([C, W]).
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    // Subtract zero points before multiplying (int8 only).
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of matmul operands
        a_val = a_val.abs();
        b_val = b_val.abs();
    }

    // Shapes for the per-batch rank-2 views and the rank-3 result slice.
    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    // Extents of a single-batch slice out of each rank-3 input.
    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    // Slice origins; index 0 is advanced through the batch dimension below.
    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        // Slice out batch i and drop the leading unit dimension.
        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        // Build the output by concatenating per-batch results along dim 0.
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    // int48 accumulators are narrower than the storage type: clamp the result.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1504
// Construct a MAX_POOL2D node: one rank-4 input, one rank-4 output.
template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    // Tensors are laid out [N, H, W, C] (see shape handling in eval()).
    setRequiredOperands(1, 1);
    setRequiredRank(4, 4);

    // Pool attribute carries kernel/stride/pad; deleted in the destructor.
    INIT_ATTRIBUTE(Pool);
}
1514
Tai Lya4d748b2023-03-28 22:06:56 +00001515template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001516OpMaxPool2d<Dtype>::~OpMaxPool2d()
1517{
1518 if (attribute)
1519 delete attribute;
1520}
1521
Tai Lya4d748b2023-03-28 22:06:56 +00001522template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001523int OpMaxPool2d<Dtype>::checkTensorAttributes()
1524{
1525 if (validateRequiredOperands())
1526 return 1;
1527
1528 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
1529 {
1530 return 1;
1531 }
1532
1533 if (inputs[0]->matchType(*outputs[0]))
1534 {
1535 printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
1536 return 1;
1537 }
1538
1539 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1540 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1541
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001542 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +00001543 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001544 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001545 msg = "OpMaxPool2d: " + msg;
1546 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001547 return 1;
1548 }
1549
1550 return 0;
1551}
1552
// Evaluate MAX_POOL2D via an im2col-style rewrite: pad the input with the
// most negative representable value, extract KHxKW patches, then take the
// argmax of each patch column to form the [N, H, W, C] output.
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // Attribute layout: pad = [top, bottom, left, right].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    // kernel/stride = [y, x].
    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    // im2col matrix: one column per output element, one row per kernel tap.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Spatial-only padding: batch and channel dims are left untouched.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KHxHW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1645
// Construct an FFT2D node: two rank-3 inputs (real, imaginary) and
// two rank-3 outputs (real, imaginary).
template <TOSA_REF_TYPE Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    setRequiredOperands(2, 2);
    setRequiredRank(3, 3);

    // FFT attribute carries the inverse() flag; deleted in the destructor.
    INIT_ATTRIBUTE(FFT);
}
1655
Tai Lya4d748b2023-03-28 22:06:56 +00001656template <TOSA_REF_TYPE Dtype>
1657OpFFT2d<Dtype>::~OpFFT2d()
1658{
Luke Hutton57287132023-02-06 14:54:18 +00001659 if (attribute)
1660 delete attribute;
1661}
1662
Tai Lya4d748b2023-03-28 22:06:56 +00001663template <TOSA_REF_TYPE Dtype>
Luke Hutton57287132023-02-06 14:54:18 +00001664int OpFFT2d<Dtype>::checkTensorAttributes()
1665{
1666 if (validateRequiredOperands())
1667 return 1;
1668
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001669 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]) ||
1670 validateRequiredRank(outputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001671 {
1672 return 1;
1673 }
1674
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001675 if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || inputs[0]->matchType(*inputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001676 {
1677 printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
1678 return 1;
1679 }
1680
1681 in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1682 in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
1683 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1684 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1685
1686 ASSERT_MEM(in_real && in_imag && out_real && out_imag);
1687
1688 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001689 if (check_fft_shape(in_real->getShape(), in_imag->getShape(), out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton57287132023-02-06 14:54:18 +00001690 {
1691 msg = "OpFFT2d: " + msg;
1692 printNodeValidationError(msg.c_str());
1693 return 1;
1694 }
1695
1696 return 0;
1697}
1698
// Evaluate FFT2D as a direct (naive) 2-D DFT over each batch:
// for each output bin (oy, ox), sum the twiddle-weighted input samples.
// attribute->inverse() flips the sign of the exponent.
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
    int in_real_batch  = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width  = this->in_real->getShape()[2];

    int in_imag_batch  = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width  = this->in_imag->getShape()[2];

    int out_real_batch  = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width  = this->out_real->getShape()[2];

    int out_imag_batch  = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width  = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP, "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);

    // sign_val selects forward (+1) vs inverse (-1) transform direction.
    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

    TIn in_real_val = this->in_real->getTensor();
    TIn in_imag_val = this->in_imag->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of real and imag operands
        in_real_val = in_real_val.abs();
        in_imag_val = in_imag_val.abs();
    }

    // Naive O(H*W * H*W) DFT per batch: accumulate each output bin directly.
    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = in_real_val(n, iy, ix);
                        OutEigenType val_imag = in_imag_val(n, iy, ix);
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI *
                            ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
                        // Complex multiply-accumulate: (val_real + j*val_imag) * e^{-j*a}.
                        sum_real += val_real * cos(a) + val_imag * sin(a);
                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1773
// Construct an RFFT2D node: one rank-3 real input and two rank-3 outputs
// (real, imaginary). RFFT2D carries no attribute, so there is no
// INIT_ATTRIBUTE call here.
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    setRequiredOperands(1, 2);
    setRequiredRank(3, 3);
}
1781
Tai Lya4d748b2023-03-28 22:06:56 +00001782template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001783OpRFFT2d<Dtype>::~OpRFFT2d()
1784{}
Luke Hutton261b7b62023-01-10 14:50:31 +00001785
Tai Lya4d748b2023-03-28 22:06:56 +00001786template <TOSA_REF_TYPE Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +00001787int OpRFFT2d<Dtype>::checkTensorAttributes()
1788{
1789 if (validateRequiredOperands())
1790 return 1;
1791
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001792 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
Luke Hutton261b7b62023-01-10 14:50:31 +00001793 {
1794 return 1;
1795 }
1796
1797 if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
1798 {
1799 printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
1800 return 1;
1801 }
1802
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001803 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
Luke Hutton261b7b62023-01-10 14:50:31 +00001804 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1805 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1806
1807 ASSERT_MEM(in && out_real && out_imag);
1808
Luke Hutton57287132023-02-06 14:54:18 +00001809 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001810 if (check_fft_shape(in->getShape(), {}, out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton261b7b62023-01-10 14:50:31 +00001811 {
Luke Hutton57287132023-02-06 14:54:18 +00001812 msg = "OpRFFT2d: " + msg;
1813 printNodeValidationError(msg.c_str());
Luke Hutton261b7b62023-01-10 14:50:31 +00001814 return 1;
1815 }
1816
1817 return 0;
1818}
1819
// Evaluate RFFT2D as a direct (naive) 2-D DFT of a real-valued input:
// for each output bin (oy, ox), accumulate cos/sin-weighted input samples.
// Output width covers only the non-redundant bins (see out_real_width).
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
    int32_t in_batch  = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width  = in->getShape()[2];

    int32_t out_real_batch  = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width  = out_real->getShape()[2];

    int32_t out_imag_batch  = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width  = out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width, out_real_batch, out_real_height, out_real_width, out_imag_batch,
               out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

    TIn in_val = this->in->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in operand
        in_val = in_val.abs();
    }

    // Naive per-batch DFT: accumulate each output bin directly.
    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        // Real input: contribution is x * e^{-j*a} = x*cos(a) - j*x*sin(a).
                        sum_real += in_val(n, iy, ix) * cos(a);
                        sum_imag += -in_val(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1882
// Construct a TRANSPOSE_CONV2D node: three rank-4 inputs (input, weight,
// bias is validated separately in checkTensorAttributes) and one rank-4 output.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4, 4);

    // TransposeConv attribute carries out_pad/stride/output_shape/zero points;
    // deleted in the destructor.
    INIT_ATTRIBUTE(TransposeConv);
}
1894
Tai Lya4d748b2023-03-28 22:06:56 +00001895template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001896OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001897{
1898 if (attribute)
1899 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001900}
1901
Tai Lya4d748b2023-03-28 22:06:56 +00001902template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001903int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001904{
1905 if (validateRequiredOperands())
1906 return 1;
1907
1908 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1909 {
1910 return 1;
1911 }
1912
James Wardd34b3fc2023-01-18 14:51:25 +00001913 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001914 "OpTransposeConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001915
Eric Kunzee5e26762020-10-13 16:11:07 -07001916 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1917 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1918 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001919 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001920
TatWai Chong24594f52022-06-08 00:48:04 -07001921 if (attribute->out_pad().size() != 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001922 {
TatWai Chong24594f52022-06-08 00:48:04 -07001923 printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
Eric Kunzee5e26762020-10-13 16:11:07 -07001924 return 1;
1925 }
1926
1927 if (attribute->stride().size() != 2)
1928 {
1929 printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
1930 return 1;
1931 }
1932
Eric Kunzee5e26762020-10-13 16:11:07 -07001933 if (attribute->output_shape().size() != 4)
1934 {
1935 printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
1936 return 1;
1937 }
1938
Kevin Cheng9fe17242021-11-10 01:04:39 +00001939 for (int32_t i : attribute->stride())
1940 {
1941 if (i < 1)
1942 {
1943 printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
1944 return 1;
1945 }
1946 }
1947
Eric Kunzee5e26762020-10-13 16:11:07 -07001948 for (int d = 0; d < 4; d++)
1949 {
1950 if (attribute->output_shape()[d] != this->output->getShape()[d])
1951 {
1952 printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
1953 return 1;
1954 }
1955 }
1956
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001957 int32_t IH = input->getShape()[1];
1958 int32_t IW = input->getShape()[2];
1959 int32_t OH = output->getShape()[1];
1960 int32_t OW = output->getShape()[2];
1961
1962 int32_t stride_y = attribute->stride()[0];
1963 int32_t stride_x = attribute->stride()[1];
1964 int32_t kernel_h = weight->getShape()[1];
1965 int32_t kernel_w = weight->getShape()[2];
1966
TatWai Chong24594f52022-06-08 00:48:04 -07001967 int32_t out_pad_top = attribute->out_pad()[0];
1968 int32_t out_pad_bottom = attribute->out_pad()[1];
1969 int32_t out_pad_left = attribute->out_pad()[2];
1970 int32_t out_pad_right = attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001971
Eric Kunzec1a97832022-07-01 16:56:09 -07001972 for (size_t i = 0; i < attribute->out_pad().size(); i++)
1973 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001974 ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
1975 "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
Eric Kunzec1a97832022-07-01 16:56:09 -07001976 }
1977
1978 int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
1979 int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001980
1981 if ((OH != H) || (OW != W))
1982 {
1983 std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001984 std::to_string(H) + "," + std::to_string(W) + ")";
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001985 printNodeValidationError(msg.c_str());
1986 return 1;
1987 }
1988
Tai Lya4d748b2023-03-28 22:06:56 +00001989 ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
1990 "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
1991 ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
1992 "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001993
Eric Kunzee5e26762020-10-13 16:11:07 -07001994 return 0;
1995}
1996
Tai Lya4d748b2023-03-28 22:06:56 +00001997template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001998int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -07001999{
2000 int in_batch = this->input->getShape()[0];
2001 int in_height = this->input->getShape()[1];
2002 int in_width = this->input->getShape()[2];
2003 int in_channels = this->input->getShape()[3];
2004
2005 int f_out_channels = this->weight->getShape()[0];
2006 int f_height = this->weight->getShape()[1];
2007 int f_width = this->weight->getShape()[2];
2008 int f_in_channels = this->weight->getShape()[3];
2009
2010 int b_out_channels = this->bias->getShape()[0];
2011
2012 int out_batch = this->output->getShape()[0];
2013 int out_height = this->output->getShape()[1];
2014 int out_width = this->output->getShape()[2];
2015 int out_channels = this->output->getShape()[3];
2016
TatWai Chong24594f52022-06-08 00:48:04 -07002017 int out_pad_top = this->attribute->out_pad()[0];
2018 int out_pad_bottom = this->attribute->out_pad()[1];
2019 int out_pad_left = this->attribute->out_pad()[2];
2020 int out_pad_right = this->attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002021
Jerry Gea793f462023-04-11 00:05:02 +00002022 int stride_y = this->attribute->stride()[0];
2023 int stride_x = this->attribute->stride()[1];
Eric Kunzee5e26762020-10-13 16:11:07 -07002024
Kevin Chengacb550f2021-06-29 15:32:19 -07002025 ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
2026 ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
2027 in_channels);
2028 ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
2029 f_out_channels, out_channels);
2030 ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
2031 out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -07002032
Jerry Gea793f462023-04-11 00:05:02 +00002033 // Check Tosa Level
2034 auto tosa_level = g_func_config.tosa_level;
2035 LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
2036 LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
2037 LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002038 LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL,
2039 "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +00002040 LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
2041 LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
2042 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
2043 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
2044
Eric Kunzee5e26762020-10-13 16:11:07 -07002045 DEBUG_INFO(OP,
2046 "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +00002047 "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002048 in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
2049 out_height, out_width, out_channels, stride_y, stride_x, out_pad_top, out_pad_bottom, out_pad_left,
2050 out_pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -07002051
2052 TIn input_val = this->input->getTensor();
2053 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +00002054 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002055 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002056 input_val = input_val - (InEigenType)attribute->input_zp();
2057 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -07002058 }
2059
Tai Ly307392a2023-05-12 21:42:19 +00002060 TBias bias_val = this->bias->getTensor();
2061
2062 if (g_func_config.abs_mode)
2063 {
2064 // in abs_mode: take abs values of conv operands
2065 input_val = input_val.abs();
2066 weight_val = weight_val.abs();
2067 bias_val = bias_val.abs();
2068 }
2069
Eric Kunzee5e26762020-10-13 16:11:07 -07002070 Eigen::array<Eigen::Index, 4> reshape_dim;
2071 reshape_dim.fill(1);
2072 reshape_dim[3] = b_out_channels;
2073
2074 Eigen::array<Eigen::Index, 4> bcast;
2075 bcast[0] = out_batch;
2076 bcast[1] = out_height;
2077 bcast[2] = out_width;
2078 bcast[3] = 1;
2079
2080 // initialize with bias
Tai Ly307392a2023-05-12 21:42:19 +00002081 this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
Eric Kunzee5e26762020-10-13 16:11:07 -07002082
2083 int out_x_origin, out_y_origin;
2084 int out_x, out_y;
2085
2086 // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
2087 for (int ob = 0; ob < out_batch; ob++)
2088 {
2089 for (int ih = 0; ih < in_height; ih++)
2090 {
2091 for (int iw = 0; iw < in_width; iw++)
2092 {
Jerry Gea793f462023-04-11 00:05:02 +00002093 out_x_origin = iw * stride_x + out_pad_left;
2094 out_y_origin = ih * stride_y + out_pad_top;
Eric Kunzee5e26762020-10-13 16:11:07 -07002095 for (int ic = 0; ic < in_channels; ic++)
2096 {
2097 for (int fh = 0; fh < f_height; fh++)
2098 {
2099 for (int fw = 0; fw < f_width; fw++)
2100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002101 out_x = out_x_origin + fw;
2102 out_y = out_y_origin + fh;
Eric Kunzee5e26762020-10-13 16:11:07 -07002103 for (int oc = 0; oc < out_channels; oc++)
2104 {
2105 if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
2106 {
2107 this->output->getTensor()(ob, out_y, out_x, oc) +=
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002108 (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
2109 (AccEigenType)weight_val(oc, fh, fw, ic));
Eric Kunzee5e26762020-10-13 16:11:07 -07002110 }
2111 }
2112 }
2113 }
2114 }
2115 }
2116 }
2117 }
2118
Tai Lya4d748b2023-03-28 22:06:56 +00002119 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -07002120 {
James Ward8b390432022-08-12 20:48:56 +01002121 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
2122 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -07002123 }
2124
2125 return GraphNode::eval();
2126}
2127
2128// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +01002129DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
James Ward24dbc422022-10-19 12:20:31 +01002130DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002131DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
Kevin Cheng3a478572021-01-22 17:21:02 -08002132DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -07002133DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
Tai Lya4d748b2023-03-28 22:06:56 +00002134DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002135
James Wardd34b3fc2023-01-18 14:51:25 +00002136DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
2137DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
2138DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
2139DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
2140DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
2141DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +00002142DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002143
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002144// [in_t, weight_t, out_t]
James Wardd34b3fc2023-01-18 14:51:25 +00002145DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
2146DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
2147DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
2148DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
2149DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
2150DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
2151DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002152DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002153
James Wardd34b3fc2023-01-18 14:51:25 +00002154DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
2155DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
2156DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
2157DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
2158DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
2159DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
2160DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002161DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
Kevin Cheng1533b852021-09-01 12:51:58 -07002162
James Wardd34b3fc2023-01-18 14:51:25 +00002163DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
2164DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
2165DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
2166DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
2167DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
2168DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
2169DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002170DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002171
Luke Hutton57287132023-02-06 14:54:18 +00002172DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +00002173DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
Luke Hutton57287132023-02-06 14:54:18 +00002174
James Wardd34b3fc2023-01-18 14:51:25 +00002175DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
2176DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
2177DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
2178DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
2179DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
2180DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
2181DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002182DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002183
James Wardd34b3fc2023-01-18 14:51:25 +00002184DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
2185DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
2186DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
2187DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
2188DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
2189DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +00002190DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002191
James Ward8b390432022-08-12 20:48:56 +01002192DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
James Ward24dbc422022-10-19 12:20:31 +01002193DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002194DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
Kevin Cheng3a478572021-01-22 17:21:02 -08002195DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -07002196DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
Tai Lya4d748b2023-03-28 22:06:56 +00002197DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002198
Luke Hutton261b7b62023-01-10 14:50:31 +00002199DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +00002200DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
Luke Hutton261b7b62023-01-10 14:50:31 +00002201
James Wardd34b3fc2023-01-18 14:51:25 +00002202DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
2203DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
2204DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
2205DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
2206DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
2207DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
2208DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002209DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);