blob: acdeeebdf03c2d452237ab0b3e84108589b3fa3a [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Luke Hutton261b7b62023-01-10 14:50:31 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "tensor_ops.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000017#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070018#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
Kevin Cheng9fe17242021-11-10 01:04:39 +000025int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
26 std::vector<int32_t> input_shape,
27 std::vector<int32_t> output_shape,
28 std::string& msg)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000029{
TatWai Chong86c403b2022-06-06 20:46:01 -070030 if (attribute->pad().size() != 4)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000031 {
32 msg = "illegal size for attribute padding";
33 return 1;
34 }
35
36 if (attribute->kernel().size() != 2)
37 {
38 msg = "illegal size for attribute kernel";
39 return 1;
40 }
41
42 if (attribute->stride().size() != 2)
43 {
44 msg = "illegal size for attribute stride";
45 return 1;
46 }
47
TatWai Chong86c403b2022-06-06 20:46:01 -070048 for (int32_t i : attribute->pad())
Kevin Cheng7eb93d72021-10-09 01:26:08 +000049 {
50 if (i < 0)
51 {
52 msg = "At least one pad is smaller than zero";
53 return 1;
54 }
55 }
56
57 for (int32_t i : attribute->kernel())
58 {
59 if (i < 1)
60 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000061 msg = "At least one kernel dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000062 return 1;
63 }
64 }
65
66 for (int32_t i : attribute->stride())
67 {
68 if (i < 1)
69 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000070 msg = "At least one stride dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000071 return 1;
72 }
73 }
74
75 int32_t IH = input_shape[1];
76 int32_t IW = input_shape[2];
77 int32_t OH = output_shape[1];
78 int32_t OW = output_shape[2];
79
TatWai Chong86c403b2022-06-06 20:46:01 -070080 int32_t pad_top = attribute->pad()[0];
81 int32_t pad_bottom = attribute->pad()[1];
82 int32_t pad_left = attribute->pad()[2];
83 int32_t pad_right = attribute->pad()[3];
Kevin Cheng7eb93d72021-10-09 01:26:08 +000084
85 int32_t stride_y = attribute->stride()[0];
86 int32_t stride_x = attribute->stride()[1];
87 int32_t kernel_y = attribute->kernel()[0];
88 int32_t kernel_x = attribute->kernel()[1];
89
90 if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
91 {
92 msg = "At least one pad is >= kernel dimension";
93 return 1;
94 }
95
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010096 int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
97 int32_t full_W = IW + pad_left + pad_right - kernel_x;
98
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 if ((full_H % stride_y != 0) || (full_W % stride_x != 0))
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100101 msg = "Parameters must yield exact integer output dimensions";
102 return 1;
103 }
104
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000105 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100106 {
107 msg = "Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000108 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000109 return 1;
110 }
111
112 return 0;
113}
114
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000115int check_conv_attribute(tosa::TosaConvAttribute* attribute,
Tai Lya4d748b2023-03-28 22:06:56 +0000116 uint32_t conv_dimension,
117 std::vector<int32_t> input_shape,
118 std::vector<int32_t> output_shape,
119 std::vector<int32_t> weights,
120 uint32_t offset_kernel,
121 TOSA_REF_TYPE InDtype,
122 TOSA_REF_TYPE WeightDtype,
123 std::string& msg)
Kevin Cheng9fe17242021-11-10 01:04:39 +0000124{
TatWai Chong86c403b2022-06-06 20:46:01 -0700125 if (attribute->pad().size() != (2 * conv_dimension))
Kevin Cheng9fe17242021-11-10 01:04:39 +0000126 {
TatWai Chong86c403b2022-06-06 20:46:01 -0700127 msg = "Illegal size for attribute pad";
Kevin Cheng9fe17242021-11-10 01:04:39 +0000128 return 1;
129 }
130
131 if (attribute->stride().size() != conv_dimension)
132 {
133 msg = "Illegal size for attribute stride";
134 return 1;
135 }
136
137 if (attribute->dilation().size() != conv_dimension)
138 {
139 msg = "Illegal size for attribute dilation";
140 return 1;
141 }
142
TatWai Chong86c403b2022-06-06 20:46:01 -0700143 for (int32_t i : attribute->pad())
Kevin Cheng9fe17242021-11-10 01:04:39 +0000144 {
145 if (i < 0)
146 {
147 msg = "At least one pad is smaller than zero";
148 return 1;
149 }
150 }
151
152 for (int32_t i : attribute->stride())
153 {
154 if (i < 1)
155 {
156 msg = "At least one stride dimension is smaller than one";
157 return 1;
158 }
159 }
160
161 for (int32_t i : attribute->dilation())
162 {
163 if (i < 1)
164 {
165 msg = "At least one dilation dimension is smaller than one";
166 return 1;
167 }
168 }
169
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100170 ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")
171
TatWai Chongfd629052022-07-25 04:01:58 +0000172 int32_t offset_d = conv_dimension == 3 ? 1 : 0;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000173 int32_t ID = conv_dimension == 3 ? input_shape[1] : 1;
174 int32_t IH = input_shape[1 + offset_d];
175 int32_t IW = input_shape[2 + offset_d];
176 int32_t OD = conv_dimension == 3 ? output_shape[1] : 1;
177 int32_t OH = output_shape[1 + offset_d];
178 int32_t OW = output_shape[2 + offset_d];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100179
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000180 int32_t stride_d = conv_dimension == 3 ? attribute->stride()[0] : 1;
181 int32_t stride_y = attribute->stride()[0 + offset_d];
182 int32_t stride_x = attribute->stride()[1 + offset_d];
183 int32_t kernel_d = conv_dimension == 3 ? weights[offset_kernel] : 1;
184 int32_t kernel_h = weights[offset_kernel + offset_d];
185 int32_t kernel_w = weights[offset_kernel + 1 + offset_d];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100186 int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
187 int32_t dilation_y = attribute->dilation()[0 + offset_d];
188 int32_t dilation_x = attribute->dilation()[1 + offset_d];
189
190 offset_d *= 2;
TatWai Chong86c403b2022-06-06 20:46:01 -0700191 int32_t pad_d0 = conv_dimension == 3 ? attribute->pad()[0] : 0;
192 int32_t pad_d1 = conv_dimension == 3 ? attribute->pad()[1] : 0;
193 int32_t pad_top = attribute->pad()[0 + offset_d];
194 int32_t pad_bottom = attribute->pad()[1 + offset_d];
195 int32_t pad_left = attribute->pad()[2 + offset_d];
196 int32_t pad_right = attribute->pad()[3 + offset_d];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100197
198 int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
199 int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
200 int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;
201
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000202 if ((full_H % stride_y != 0) || (full_W % stride_x != 0) || (full_D % stride_d != 0))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100203 {
204 msg = "Parameters must yield exact integer output dimensions";
205 return 1;
206 }
207
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000208 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1) || (OD != (full_D / stride_d) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100209 {
210 std::string msg_d = "";
211 if (conv_dimension == 3)
212 {
213 msg_d += std::to_string((full_D / stride_d) + 1) + ",";
214 }
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000215 msg = "Mismatch between output shape provided and expected output shape (" + msg_d +
216 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100217 return 1;
218 }
219
Tai Lya4d748b2023-03-28 22:06:56 +0000220 if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
221 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000222 msg = "Input zero point must be zero for non-int8 data";
223 return 1;
224 }
Tai Lya4d748b2023-03-28 22:06:56 +0000225 if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
226 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000227 msg = "Weight zero point must be zero for non-int8 data";
228 return 1;
Kevin Cheng9fe17242021-11-10 01:04:39 +0000229 }
230
231 return 0;
232}
233
// Validate the real/imaginary input and output shapes of FFT2D/RFFT2D.
// Shapes are [N, H, W]. An empty 'in_imag' marks RFFT2D (real-only input).
// On failure, writes a description into 'msg' and returns 1; returns 0 on success.
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg)
{
    const bool is_rfft = in_imag.empty();
    auto is_power_of_two = [](int32_t n) -> bool { return (n & (n - 1)) == 0 && n > 0; };

    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // RFFT does not have a second input.
    // BUG FIX: compare whole vectors instead of indexing in_imag[i]/out_imag[i]
    // with the other vector's size as the bound — operator!= checks sizes first,
    // so a rank mismatch is reported instead of reading out of bounds.
    if (!is_rfft && in_real != in_imag)
    {
        msg = "Mismatch between real input shape and imaginary input shape";
        return 1;
    }

    if (out_real != out_imag)
    {
        msg = "Mismatch between real output shape and imaginary output shape";
        return 1;
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (is_rfft)
    {
        // RFFT output keeps only the non-redundant half of the spectrum.
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    }
    else
    {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}
313
// ARGMAX node: one input tensor, one output tensor of rank >= 1.
template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1);

    // Capture the Axis attribute; released in the destructor.
    INIT_ATTRIBUTE(Axis);
}
323
template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
    // Release the attribute set up by INIT_ATTRIBUTE in the constructor.
    if (attribute)
        delete attribute;
}
330
Tai Lya4d748b2023-03-28 22:06:56 +0000331template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700332int OpArgMax<Rank, Dtype>::checkTensorAttributes()
333{
334 if (validateRequiredOperands())
335 return 1;
336
Kevin Chengcc61be32021-10-14 17:09:57 -0700337 if (validateRequiredRank(inputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700338 {
339 return 1;
340 }
341
Kevin Chengcc61be32021-10-14 17:09:57 -0700342 int32_t output_rank = inputs[0]->getRank() - 1;
343 if (output_rank != outputs[0]->getRank())
344 {
345 printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
346 return 1;
347 }
348
Tai Lya4d748b2023-03-28 22:06:56 +0000349 if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
Kevin Chengcc61be32021-10-14 17:09:57 -0700350 {
351 printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
352 return 1;
353 }
354
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
356 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
357
Kevin Chengcc61be32021-10-14 17:09:57 -0700358 if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
359 {
360 printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
361 return 1;
362 }
363
364 bool shape_check = true;
365 for (int32_t i = 0; i < input->getRank(); i++)
366 {
367 if (i < attribute->axis())
368 {
369 if (input->getShape()[i] != output->getShape()[i])
370 {
371 shape_check = false;
372 break;
373 }
374 }
375 else if (i > attribute->axis())
376 {
377 if (input->getShape()[i] != output->getShape()[i - 1])
378 {
379 shape_check = false;
380 break;
381 }
382 }
383 // No need to check i == axis
384 }
385 if (!shape_check)
386 {
387 printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
388 return 1;
389 }
390
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 return 0;
392}
393
Tai Lya4d748b2023-03-28 22:06:56 +0000394template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700395int OpArgMax<Rank, Dtype>::eval()
396{
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000397 // Check Tosa Level
398 auto tosa_level = g_func_config.tosa_level;
399 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
400
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
402
403 this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
404
405 return GraphNode::eval();
406}
407
// AVG_POOL2D node: one rank-4 (NHWC) input, one rank-4 output.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4, 4);

    // Capture the pooling attribute (kernel/stride/pad/zero points/accum dtype).
    INIT_ATTRIBUTE(Pool);
}
417
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
    // Release the attribute set up by INIT_ATTRIBUTE in the constructor.
    if (attribute)
        delete attribute;
}
424
Tai Lya4d748b2023-03-28 22:06:56 +0000425template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100426int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700427{
428 if (validateRequiredOperands())
429 return 1;
430
431 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
432 {
433 return 1;
434 }
435
436 if (inputs[0]->matchType(*outputs[0]))
437 {
438 printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
439 return 1;
440 }
441
442 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
443 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
444
Tai Lya4d748b2023-03-28 22:06:56 +0000445 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
446 "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
447 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
448 "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
Eric Kunzee5e26762020-10-13 16:11:07 -0700449
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000450 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +0000451 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700452 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000453 msg = "OpAvgPool2d: " + msg;
454 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return 1;
456 }
457
458 return 0;
459}
460
Eric Kunze830add42022-01-25 22:56:46 -0800461// This calculates the number of padding elements used for each location along an axis
462// Average pooling only divides by the number of elements used, not including padding.
463// This function uses left/right, but is also used for vertical padding with top/bottom
Tai Lya4d748b2023-03-28 22:06:56 +0000464template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
465ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
466 int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
Eric Kunzee5e26762020-10-13 16:11:07 -0700467{
468 ETensor1<int32_t> result(out_size);
469
Eric Kunzee5e26762020-10-13 16:11:07 -0700470 result.setConstant(kernel_size);
471
Eric Kunze830add42022-01-25 22:56:46 -0800472 // adjust divisors on the left side for padding
473 // We start at the leftmost output element, and remove pad_left - (index * stride) elements
474 // until we have no more padding being used
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000475 for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++)
476 {
Eric Kunze830add42022-01-25 22:56:46 -0800477 int32_t adjust = pad_left - (index * stride);
478 result(index) -= adjust;
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 }
480
Eric Kunze830add42022-01-25 22:56:46 -0800481 // The process repeats on the right side. Padding starts taking effect as we
482 // near the rightmost input element. The first output element which touches
483 // padding is defined in the initialization of index below. Then we keep moving
484 // to the right, increasing padding until we get to the last output element.
485 int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000486 for (; index < out_size; index++)
487 {
Eric Kunze830add42022-01-25 22:56:46 -0800488 int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
489 result(index) -= adjust;
Eric Kunzee5e26762020-10-13 16:11:07 -0700490 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700491 return result;
492}
493
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
    // Input/output are NHWC rank-4 tensors (rank enforced in checkTensorAttributes).
    int in_batch = this->in->getShape()[0];
    int in_height = this->in->getShape()[1];
    int in_width = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch = this->out->getShape()[0];
    int out_height = this->out->getShape()[1];
    int out_width = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    // Pooling preserves batch and channel counts.
    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // Attribute layout: pad = {top, bottom, left, right}, kernel/stride = {y, x}.
    int pad_top = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left = this->attribute->pad()[2];
    int pad_right = this->attribute->pad()[3];
    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    // im2col layout: each column is one pooling window (kernel_y*kernel_x rows),
    // one column per output element.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    // Final output layout [N, H, W, C].
    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Zero padding on height/width only; batch and channel are not padded.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        // Remove the input zero point before accumulation (int8 only).
        input_val = input_val - (InEigenType)attribute->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of input_padded
        input_padded = input_padded.abs();
    }

    // assuming input and output have same scales
    // so input and output scaling is not required
    // TODO: check if this assumption TOSA made

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool: accumulate each window's elements in the accumulator type.
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_y * kernel_x; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide with div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_y, stride_y, pad_top, pad_bottom);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_x, stride_x, pad_left, pad_right);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    // Outer product of the H and W divisor vectors via contraction.
    ETensor2<int32_t> dm2_w = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
    ETensor2<int32_t> dm2_h = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
    ETensor4<int32_t> div_map = dm2_h.contract(dm2_w, contract_dims)
                                    .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
                                    .broadcast(bcast);
    if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
        Dtype != TOSA_REF_TYPE_FP64)
    {
        // Integer path: divide by the element count using the TOSA
        // reciprocal-scale / apply_scale_32 fixed-point sequence, then restore
        // the output zero point and clamp to the output type's range.
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        // Case for float-types: plain elementwise division by the element count.
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}
638
// CONV2D node: three inputs (input, weight, bias), one output; rank-4 tensors.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4, 4);

    // Capture the convolution attribute (pad/stride/dilation/zero points).
    INIT_ATTRIBUTE(Conv);
}
648
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
    // Release the attribute set up by INIT_ATTRIBUTE in the constructor.
    if (attribute)
        delete attribute;
}
655
Tai Lya4d748b2023-03-28 22:06:56 +0000656template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000657int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700658{
659 if (validateRequiredOperands())
660 return 1;
661
662 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
663 {
664 return 1;
665 }
666
667 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
668 if (inputs[2]->getRank() != 1)
669 {
670 printNodeValidationError("OpConv2d: bias tensor must be rank 1");
671 }
672
James Wardd34b3fc2023-01-18 14:51:25 +0000673 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000674 "OpConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700675
Eric Kunzee5e26762020-10-13 16:11:07 -0700676 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
677 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
678 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100679 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700680
Kevin Cheng9fe17242021-11-10 01:04:39 +0000681 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000682 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000683 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000685 msg = "OpConv2d: " + msg;
686 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700687 return 1;
688 }
689
Eric Kunzee5e26762020-10-13 16:11:07 -0700690 return 0;
691}
692
Tai Lya4d748b2023-03-28 22:06:56 +0000693template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000694int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700695{
696 int in_batch = this->input->getShape()[0];
697 int in_height = this->input->getShape()[1];
698 int in_width = this->input->getShape()[2];
699 int in_channels = this->input->getShape()[3];
700
701 int f_out_channels = this->weight->getShape()[0];
702 int f_height = this->weight->getShape()[1];
703 int f_width = this->weight->getShape()[2];
704 int f_in_channels = this->weight->getShape()[3];
705
706 int b_out_channels = this->bias->getShape()[0];
707
708 int out_batch = this->output->getShape()[0];
709 int out_height = this->output->getShape()[1];
710 int out_width = this->output->getShape()[2];
711 int out_channels = this->output->getShape()[3];
712
Kevin Chengacb550f2021-06-29 15:32:19 -0700713 ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
714 ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
715 in_channels);
716 ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
717 out_channels);
718 ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700719
TatWai Chong86c403b2022-06-06 20:46:01 -0700720 int pad_top = this->attribute->pad()[0];
721 int pad_bottom = this->attribute->pad()[1];
722 int pad_left = this->attribute->pad()[2];
723 int pad_right = this->attribute->pad()[3];
724
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000725 int stride_y = this->attribute->stride()[0];
726 int stride_x = this->attribute->stride()[1];
727 int dilation_y = this->attribute->dilation()[0];
728 int dilation_x = this->attribute->dilation()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000729
730 // Check Tosa Level
731 auto tosa_level = g_func_config.tosa_level;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000732 LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
733 "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
734 LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
735 "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +0000736 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
737 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
738 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
739 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
740 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
741 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
Eric Kunzee5e26762020-10-13 16:11:07 -0700742
743 DEBUG_INFO(OP,
744 "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +0000745 "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
Eric Kunzee5e26762020-10-13 16:11:07 -0700746 in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000747 out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
748 pad_left, pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -0700749
750 // GEMM-conv2d, left matrix is input, right matrix is weight
751 Eigen::array<Eigen::Index, 2> im2col_input_dims;
752 im2col_input_dims[0] = out_batch * out_height * out_width;
753 im2col_input_dims[1] = f_height * f_width * f_in_channels;
754
755 Eigen::array<Eigen::Index, 2> im2col_weight_dims;
756 im2col_weight_dims[0] = f_height * f_width * f_in_channels;
757 im2col_weight_dims[1] = f_out_channels;
758
759 Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
760 bias_reshaped_dims[0] = 1;
761 bias_reshaped_dims[1] = b_out_channels;
762
763 Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
764 weight_zp_bcast_dims[0] = f_height;
765 weight_zp_bcast_dims[1] = f_width;
766 weight_zp_bcast_dims[2] = f_in_channels;
767
768 Eigen::array<Eigen::Index, 2> bias_bcast_dims;
769 bias_bcast_dims[0] = out_batch * out_height * out_width;
770 bias_bcast_dims[1] = 1;
771
772 Eigen::array<Eigen::Index, 4> col2im_output_dims;
773 col2im_output_dims[0] = out_batch;
774 col2im_output_dims[1] = out_height;
775 col2im_output_dims[2] = out_width;
776 col2im_output_dims[3] = out_channels;
777
778 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
779
TatWai Chong86c403b2022-06-06 20:46:01 -0700780 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
781 pad[0] = std::make_pair(0, 0);
782 pad[1] = std::make_pair(pad_top, pad_bottom);
783 pad[2] = std::make_pair(pad_left, pad_right);
784 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700785
786 TIn input_val = this->input->getTensor();
787 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000788 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700789 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000790 input_val = input_val - (InEigenType)attribute->input_zp();
791 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700792 }
793
TatWai Chong86c403b2022-06-06 20:46:01 -0700794 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700795
Tai Ly307392a2023-05-12 21:42:19 +0000796 TBias bias_val = this->bias->getTensor();
797
798 if (g_func_config.abs_mode)
799 {
800 // in abs_mode: take abs values of conv operands
801 input_padded = input_padded.abs();
802 weight_val = weight_val.abs();
803 bias_val = bias_val.abs();
804 }
805
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 // extract_image_patches() output [N, KH, KW, H * W, C]
807 // need to transpose to [N, H * W, KH, KW, C]
808 ETensor5<InEigenType> input_extract_patches =
809 input_padded
Jerry Gea793f462023-04-11 00:05:02 +0000810 .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700811 .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
812
813 // reshape input to [N * H * W, KH * KW * C]
814 ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
815
816 // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
817 ETensor2<WeightEigenType> im2col_weight =
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000818 weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
Eric Kunzee5e26762020-10-13 16:11:07 -0700819
820 // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
821 // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
Tai Ly307392a2023-05-12 21:42:19 +0000822 ETensor2<OutEigenType> bias_2d =
823 (bias_val.reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700824
825 // output matrix is [N * H * W, C]
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000826 ETensor2<OutEigenType> contracted_result = (im2col_input.template cast<AccEigenType>().contract(
827 im2col_weight.template cast<AccEigenType>(), contract_dims))
828 .template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700829
830 // adding bias
James Ward8b390432022-08-12 20:48:56 +0100831 ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700832
833 // reshape back to [N, H, W, C]
834 this->output->getTensor() = biased_output.reshape(col2im_output_dims);
835
Tai Lya4d748b2023-03-28 22:06:56 +0000836 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -0700837 {
James Ward8b390432022-08-12 20:48:56 +0100838 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
839 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700840 }
841
842 return GraphNode::eval();
843}
844
Tai Lya4d748b2023-03-28 22:06:56 +0000845template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000846OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Cheng1533b852021-09-01 12:51:58 -0700847 : GraphNode(sgt_, Op_CONV3D, id_)
848{
849 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000850 setRequiredRank(5, 5);
Kevin Cheng1533b852021-09-01 12:51:58 -0700851
852 INIT_ATTRIBUTE(Conv);
Kevin Cheng1533b852021-09-01 12:51:58 -0700853}
854
Tai Lya4d748b2023-03-28 22:06:56 +0000855template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000856OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
Kevin Cheng1533b852021-09-01 12:51:58 -0700857{
858 if (attribute)
859 delete attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700860}
861
Tai Lya4d748b2023-03-28 22:06:56 +0000862template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000863int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Kevin Cheng1533b852021-09-01 12:51:58 -0700864{
865 if (validateRequiredOperands())
866 return 1;
867
868 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
869 {
870 return 1;
871 }
872
873 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
874 if (inputs[2]->getRank() != 1)
875 {
876 printNodeValidationError("OpConv3d: bias tensor must be rank 1");
877 }
878
James Wardd34b3fc2023-01-18 14:51:25 +0000879 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000880 "OpConv3d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700881
Kevin Cheng1533b852021-09-01 12:51:58 -0700882 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
883 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
884 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100885 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng1533b852021-09-01 12:51:58 -0700886
Kevin Cheng9fe17242021-11-10 01:04:39 +0000887 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000888 if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000889 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Kevin Cheng1533b852021-09-01 12:51:58 -0700890 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000891 msg = "OpConv3d: " + msg;
892 printNodeValidationError(msg.c_str());
Kevin Cheng1533b852021-09-01 12:51:58 -0700893 return 1;
894 }
895
Kevin Cheng1533b852021-09-01 12:51:58 -0700896 return 0;
897}
898
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Evaluate CONV3D by direct (naive) convolution: initialize the output
    // with the broadcast bias, then accumulate input * weight over the whole
    // kernel volume for every output element.
    //
    // Shapes (from the indexing below): input [N, D, H, W, IC],
    // weight [OC, KD, KH, KW, IC], bias [OC], output [N, OD, OH, OW, OC].
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    // Cross-tensor consistency checks.
    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    // Attribute layout: pad = [d0, d1, top, bottom, left, right],
    // stride/dilation = [d, y, x].
    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL,
                "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    // Zero-pad only the spatial dimensions (D, H, W).
    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // Remove zero points before the arithmetic (INT8 only per TOSA).
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // 1. initialize with bias: reshape bias [OC] to [1,1,1,1,OC] and
    // broadcast it across the batch and spatial dimensions of the output.
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0]                  = out_batch;
    bcast[1]                  = out_depth;
    bcast[2]                  = out_height;
    bcast[3]                  = out_width;
    bcast[4]                  = 1;
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            // Map output coordinate + kernel tap to the padded
                            // input coordinate via stride and dilation.
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        // Narrow the accumulator to the output element type.
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate to the 48-bit accumulator range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1058
Tai Lya4d748b2023-03-28 22:06:56 +00001059template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001060OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
Tai Lya4d748b2023-03-28 22:06:56 +00001061 TosaAttributeBase* attribute_,
1062 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001063 : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001064{
1065 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001066 setRequiredRank(4, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -07001067
Kevin Cheng93a16282021-08-31 16:14:03 -07001068 INIT_ATTRIBUTE(Conv);
Eric Kunzee5e26762020-10-13 16:11:07 -07001069}
1070
Tai Lya4d748b2023-03-28 22:06:56 +00001071template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001072OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001073{
1074 if (attribute)
1075 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001076}
1077
Tai Lya4d748b2023-03-28 22:06:56 +00001078template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001079int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001080{
1081 if (validateRequiredOperands())
1082 return 1;
1083
1084 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1085 {
1086 return 1;
1087 }
1088
1089 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
1090 if (inputs[2]->getRank() != 1)
1091 {
1092 printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
1093 }
1094
James Wardd34b3fc2023-01-18 14:51:25 +00001095 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001096 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001097
Eric Kunzee5e26762020-10-13 16:11:07 -07001098 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1099 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1100 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001101 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001102
Kevin Cheng9fe17242021-11-10 01:04:39 +00001103 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001104 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001105 weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001106 {
Kevin Cheng9fe17242021-11-10 01:04:39 +00001107 msg = "OpDepthwiseConv2d: " + msg;
1108 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001109 return 1;
1110 }
1111
Eric Kunzee5e26762020-10-13 16:11:07 -07001112 return 0;
1113}
1114
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Evaluate DEPTHWISE_CONV2D: each input channel is convolved with
    // f_multiplier independent kernels, producing in_channels * f_multiplier
    // output channels.
    //
    // Shapes (from the indexing below): input [N, H, W, IC],
    // weight [KH, KW, IC, M], bias [IC * M], output [N, OH, OW, IC * M].
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // Cross-tensor consistency checks.
    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    // Attribute layout: pad = [top, bottom, left, right], stride/dilation = [y, x].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
               pad_left, pad_right);

    // Zero-pad only the spatial dimensions (H, W).
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // Remove zero points before the arithmetic (INT8 only per TOSA).
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution
    
    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    // Reshape bias [C] to [1,1,1,C] and broadcast across batch and space.
    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType
                                // NOTE(review): the patch index is
                                // `ow * out_height + oh` — this matches
                                // Eigen's patch ordering as used here; confirm
                                // against extract_image_patches docs before
                                // changing.
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh,
                                                                                       ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate to the 48-bit accumulator range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1256
Tai Lya4d748b2023-03-28 22:06:56 +00001257template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001258OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
Tai Lya4d748b2023-03-28 22:06:56 +00001259 TosaAttributeBase* attribute_,
1260 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001261 : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001262{
1263 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +00001264 setRequiredRank(2, 2);
Eric Kunzee5e26762020-10-13 16:11:07 -07001265
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001266 INIT_ATTRIBUTE(FullyConnected);
Eric Kunzee5e26762020-10-13 16:11:07 -07001267}
1268
Tai Lya4d748b2023-03-28 22:06:56 +00001269template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001270OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001271{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001272 if (attribute)
1273 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001274}
1275
// Validates operand/result tensors and attributes for FULLY_CONNECTED.
// Returns 0 on success, 1 on any validation failure (after printing a
// node validation error). ERROR_IF reports and fails via the macro's own path.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    // Ranks must match the requirement set in the constructor (rank 2).
    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    // input: [N, IC], weight: [OC, IC] -- inner (channel) dimensions must agree.
    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    // bias: [OC] -- one bias element per output channel of the weight.
    if (weight->getShape()[0] != bias->getShape()[0])
    {
        printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Zero points are only meaningful for int8 operands per the TOSA spec.
    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
             "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
1315
// Computes FULLY_CONNECTED: output = input * weight^T + bias,
// accumulating in AccEigenType and casting the result to OutEigenType.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    // Contract input dim 1 (IC) against transposed-weight dim 0 (IC).
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    // Transpose weight from [OC, IC] to [IC, OC] for the contraction.
    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    // Reshape bias [OC] to [1, OC] ...
    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = this->bias->getShape()[0];

    // ... then broadcast it across the batch dimension [N, OC].
    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    // Subtract zero points before multiplication for int8 operands.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_val  = input_val.abs();
        weight_val = weight_val.abs();
        bias_val   = bias_val.abs();
    }

    this->output->getTensor() = input_val.template cast<AccEigenType>()
                                    .contract(weight_val.template cast<AccEigenType>(), dims)
                                    .template cast<OutEigenType>() +
                                bias_val.reshape(bias_reshape).broadcast(bias_bcast);

    // INT48 accumulators must be clamped to the 48-bit range [AccQMin, AccQMax].
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}
1362
// MATMUL node: expects 2 inputs / 1 output, all rank-3 ([N, H, C] x [N, C, W]).
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(3, 3);

    // Copies the serialized MatMul attribute; freed in the destructor.
    INIT_ATTRIBUTE(MatMul);
}
1372
Tai Lya4d748b2023-03-28 22:06:56 +00001373template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001374OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001375{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001376 if (attribute)
1377 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001378}
1379
Tai Lya4d748b2023-03-28 22:06:56 +00001380template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001381int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001382{
1383 if (validateRequiredOperands())
1384 return 1;
1385
1386 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1387 {
1388 return 1;
1389 }
1390
James Wardd34b3fc2023-01-18 14:51:25 +00001391 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001392 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001393
Kevin Cheng2d60f002021-06-09 14:18:32 -07001394 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1395 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001396 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001397
Kevin Cheng2d60f002021-06-09 14:18:32 -07001398 ASSERT_MEM(a && b && output);
1399
1400 // a: [N, H, C]
1401 // b: [N, C, W]
1402 // c: [N, H, W]
1403
1404 // Check N
1405 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001406 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001407 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001408 return 1;
1409 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001410 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001411
Kevin Cheng2d60f002021-06-09 14:18:32 -07001412 // Check C
1413 if (a->getShape()[2] != b->getShape()[1])
1414 {
1415 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1416 return 1;
1417 }
1418 C = a->getShape()[2];
1419
1420 // Check H
1421 if (a->getShape()[1] != output->getShape()[1])
1422 {
1423 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1424 return 1;
1425 }
1426 H = a->getShape()[1];
1427
1428 // Check W
1429 if (b->getShape()[2] != output->getShape()[2])
1430 {
1431 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1432 return 1;
1433 }
1434 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001435
Tai Lya4d748b2023-03-28 22:06:56 +00001436 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
1437 "OpMatMul: A zeropoint must be zero for non int8_t data");
1438 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
1439 "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001440
Eric Kunzee5e26762020-10-13 16:11:07 -07001441 return 0;
1442}
1443
// Computes batched MATMUL: for each batch n, output[n] = a[n] x b[n],
// accumulating in AccEigenType and casting to OutEigenType.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    // Contract a's dim 1 (C, after rank-2 reshape) with b's dim 0 (C).
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    // Subtract zero points before multiplication for int8 operands.
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of matmul operands
        a_val = a_val.abs();
        b_val = b_val.abs();
    }

    // Per-batch views: a slice [1, H, C] reshaped to [H, C], b slice [1, C, W]
    // reshaped to [C, W]; each result is reshaped back to [1, H, W].
    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        // Build the output by concatenating one batch slice at a time along dim 0.
        // NOTE(review): this re-copies the accumulated output each iteration;
        // acceptable for a reference model, but O(N^2) in the batch dimension.
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    // INT48 accumulators must be clamped to the 48-bit range [AccQMin, AccQMax].
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1505
// MAX_POOL2D node: expects 1 input / 1 output, both rank-4 NHWC tensors.
template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4, 4);

    // Copies the serialized Pool attribute (kernel/stride/pad); freed in the destructor.
    INIT_ATTRIBUTE(Pool);
}
1515
Tai Lya4d748b2023-03-28 22:06:56 +00001516template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001517OpMaxPool2d<Dtype>::~OpMaxPool2d()
1518{
1519 if (attribute)
1520 delete attribute;
1521}
1522
// Validates MAX_POOL2D operands and pooling attributes.
// Returns 0 on success, 1 on failure (after printing a validation error).
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // matchType() is non-zero on mismatch; max-pool preserves the element type.
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Shared helper validates kernel/stride/pad sizes and value ranges
    // against the input/output shapes; msg carries the failure detail.
    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpMaxPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
1553
// Computes MAX_POOL2D via an im2col-style patch extraction followed by a
// column-wise argmax. Padding uses the type's lowest value so padded cells
// never win the max.
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // pad() layout is [top, bottom, left, right].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    // kernel()/stride() layout is [y, x].
    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    // im2col matrix: one column per output element, one row per kernel cell.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    // Final reshape back to NHWC.
    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Spatial-only padding: batch and channel dims get zero padding.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KHxHW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1646
// FFT2D node: expects 2 inputs (real, imag) / 2 outputs (real, imag),
// all rank-3 [N, H, W] tensors.
template <TOSA_REF_TYPE Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    setRequiredOperands(2, 2);
    setRequiredRank(3, 3);

    // Copies the serialized FFT attribute (carries the inverse flag); freed in the destructor.
    INIT_ATTRIBUTE(FFT);
}
1656
Tai Lya4d748b2023-03-28 22:06:56 +00001657template <TOSA_REF_TYPE Dtype>
1658OpFFT2d<Dtype>::~OpFFT2d()
1659{
Luke Hutton57287132023-02-06 14:54:18 +00001660 if (attribute)
1661 delete attribute;
1662}
1663
// Validates FFT2D operands: real/imag input pair and real/imag output pair
// must all share a type, and shapes must satisfy the FFT shape rules.
// Returns 0 on success, 1 on failure.
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]) ||
        validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    // matchType() is non-zero on mismatch; all four tensors must agree.
    if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || inputs[0]->matchType(*inputs[1]))
    {
        printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in_real  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    in_imag  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in_real && in_imag && out_real && out_imag);

    // Shared helper validates the real/imag shape relationships; msg carries detail.
    std::string msg;
    if (check_fft_shape(in_real->getShape(), in_imag->getShape(), out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
1699
// Computes FFT2D as a direct (naive) 2-D DFT over each batch: O(H*W) work
// per output element. attribute->inverse() flips the sign of the twiddle
// angle to select the inverse transform direction.
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
    int in_real_batch  = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width  = this->in_real->getShape()[2];

    int in_imag_batch  = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width  = this->in_imag->getShape()[2];

    int out_real_batch  = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width  = this->out_real->getShape()[2];

    int out_imag_batch  = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width  = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP, "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);

    // sign_val selects forward (+1) vs inverse (-1) transform direction.
    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

    TIn in_real_val = this->in_real->getTensor();
    TIn in_imag_val = this->in_imag->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of real and imag operands
        in_real_val = in_real_val.abs();
        in_imag_val = in_imag_val.abs();
    }

    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = in_real_val(n, iy, ix);
                        OutEigenType val_imag = in_imag_val(n, iy, ix);
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI *
                            ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
                        sum_real += val_real * cos(a) + val_imag * sin(a);
                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1774
// RFFT2D node: 1 real input, 2 outputs (real, imag), all rank-3.
// RFFT2D carries no serialized attribute, so no INIT_ATTRIBUTE here.
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    setRequiredOperands(1, 2);
    setRequiredRank(3, 3);
}
1782
// Nothing to release: the constructor allocates no attribute.
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d()
{}
Luke Hutton261b7b62023-01-10 14:50:31 +00001786
// Validates RFFT2D operands: the single real input and both outputs must
// share a type, and shapes must satisfy the FFT shape rules (no imaginary
// input, hence the empty shape passed to check_fft_shape).
// Returns 0 on success, 1 on failure.
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    // matchType() is non-zero on mismatch.
    if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
    {
        printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in       = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in && out_real && out_imag);

    std::string msg;
    if (check_fft_shape(in->getShape(), {}, out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpRFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
1820
// Computes RFFT2D as a direct (naive) real-input 2-D DFT over each batch.
// Only the real input contributes, so each output element is a plain
// sum of cos/sin-weighted input samples.
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
    int32_t in_batch  = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width  = in->getShape()[2];

    int32_t out_real_batch  = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width  = out_real->getShape()[2];

    int32_t out_imag_batch  = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width  = out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width, out_real_batch, out_real_height, out_real_width, out_imag_batch,
               out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

    TIn in_val = this->in->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in operand
        in_val = in_val.abs();
    }

    // Output width covers only the non-redundant half of the spectrum, so the
    // output loops are bounded by out_real_* while the inner sum spans the input.
    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        sum_real += in_val(n, iy, ix) * cos(a);
                        sum_imag += -in_val(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1883
// TRANSPOSE_CONV2D node: expects 3 inputs (input, weight, bias) / 1 output,
// with rank-4 tensors required where ranks are checked.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4, 4);

    // Copies the serialized TransposeConv attribute; freed in the destructor.
    INIT_ATTRIBUTE(TransposeConv);
}
1895
Tai Lya4d748b2023-03-28 22:06:56 +00001896template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001897OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001898{
1899 if (attribute)
1900 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001901}
1902
Tai Lya4d748b2023-03-28 22:06:56 +00001903template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001904int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001905{
1906 if (validateRequiredOperands())
1907 return 1;
1908
1909 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1910 {
1911 return 1;
1912 }
1913
James Wardd34b3fc2023-01-18 14:51:25 +00001914 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001915 "OpTransposeConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001916
Eric Kunzee5e26762020-10-13 16:11:07 -07001917 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1918 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1919 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001920 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001921
TatWai Chong24594f52022-06-08 00:48:04 -07001922 if (attribute->out_pad().size() != 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001923 {
TatWai Chong24594f52022-06-08 00:48:04 -07001924 printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
Eric Kunzee5e26762020-10-13 16:11:07 -07001925 return 1;
1926 }
1927
1928 if (attribute->stride().size() != 2)
1929 {
1930 printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
1931 return 1;
1932 }
1933
Eric Kunzee5e26762020-10-13 16:11:07 -07001934 if (attribute->output_shape().size() != 4)
1935 {
1936 printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
1937 return 1;
1938 }
1939
Kevin Cheng9fe17242021-11-10 01:04:39 +00001940 for (int32_t i : attribute->stride())
1941 {
1942 if (i < 1)
1943 {
1944 printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
1945 return 1;
1946 }
1947 }
1948
Eric Kunzee5e26762020-10-13 16:11:07 -07001949 for (int d = 0; d < 4; d++)
1950 {
1951 if (attribute->output_shape()[d] != this->output->getShape()[d])
1952 {
1953 printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
1954 return 1;
1955 }
1956 }
1957
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001958 int32_t IH = input->getShape()[1];
1959 int32_t IW = input->getShape()[2];
1960 int32_t OH = output->getShape()[1];
1961 int32_t OW = output->getShape()[2];
1962
1963 int32_t stride_y = attribute->stride()[0];
1964 int32_t stride_x = attribute->stride()[1];
1965 int32_t kernel_h = weight->getShape()[1];
1966 int32_t kernel_w = weight->getShape()[2];
1967
TatWai Chong24594f52022-06-08 00:48:04 -07001968 int32_t out_pad_top = attribute->out_pad()[0];
1969 int32_t out_pad_bottom = attribute->out_pad()[1];
1970 int32_t out_pad_left = attribute->out_pad()[2];
1971 int32_t out_pad_right = attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001972
Eric Kunzec1a97832022-07-01 16:56:09 -07001973 for (size_t i = 0; i < attribute->out_pad().size(); i++)
1974 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001975 ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
1976 "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
Eric Kunzec1a97832022-07-01 16:56:09 -07001977 }
1978
1979 int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
1980 int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001981
1982 if ((OH != H) || (OW != W))
1983 {
1984 std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001985 std::to_string(H) + "," + std::to_string(W) + ")";
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001986 printNodeValidationError(msg.c_str());
1987 return 1;
1988 }
1989
Tai Lya4d748b2023-03-28 22:06:56 +00001990 ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
1991 "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
1992 ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
1993 "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001994
Eric Kunzee5e26762020-10-13 16:11:07 -07001995 return 0;
1996}
1997
Tai Lya4d748b2023-03-28 22:06:56 +00001998template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001999int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -07002000{
2001 int in_batch = this->input->getShape()[0];
2002 int in_height = this->input->getShape()[1];
2003 int in_width = this->input->getShape()[2];
2004 int in_channels = this->input->getShape()[3];
2005
2006 int f_out_channels = this->weight->getShape()[0];
2007 int f_height = this->weight->getShape()[1];
2008 int f_width = this->weight->getShape()[2];
2009 int f_in_channels = this->weight->getShape()[3];
2010
2011 int b_out_channels = this->bias->getShape()[0];
2012
2013 int out_batch = this->output->getShape()[0];
2014 int out_height = this->output->getShape()[1];
2015 int out_width = this->output->getShape()[2];
2016 int out_channels = this->output->getShape()[3];
2017
TatWai Chong24594f52022-06-08 00:48:04 -07002018 int out_pad_top = this->attribute->out_pad()[0];
2019 int out_pad_bottom = this->attribute->out_pad()[1];
2020 int out_pad_left = this->attribute->out_pad()[2];
2021 int out_pad_right = this->attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002022
Jerry Gea793f462023-04-11 00:05:02 +00002023 int stride_y = this->attribute->stride()[0];
2024 int stride_x = this->attribute->stride()[1];
Eric Kunzee5e26762020-10-13 16:11:07 -07002025
Kevin Chengacb550f2021-06-29 15:32:19 -07002026 ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
2027 ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
2028 in_channels);
2029 ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
2030 f_out_channels, out_channels);
2031 ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
2032 out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -07002033
Jerry Gea793f462023-04-11 00:05:02 +00002034 // Check Tosa Level
2035 auto tosa_level = g_func_config.tosa_level;
2036 LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
2037 LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
2038 LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002039 LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL,
2040 "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +00002041 LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
2042 LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
2043 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
2044 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
2045
Eric Kunzee5e26762020-10-13 16:11:07 -07002046 DEBUG_INFO(OP,
2047 "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +00002048 "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002049 in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
2050 out_height, out_width, out_channels, stride_y, stride_x, out_pad_top, out_pad_bottom, out_pad_left,
2051 out_pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -07002052
2053 TIn input_val = this->input->getTensor();
2054 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +00002055 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002056 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002057 input_val = input_val - (InEigenType)attribute->input_zp();
2058 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -07002059 }
2060
Tai Ly307392a2023-05-12 21:42:19 +00002061 TBias bias_val = this->bias->getTensor();
2062
2063 if (g_func_config.abs_mode)
2064 {
2065 // in abs_mode: take abs values of conv operands
2066 input_val = input_val.abs();
2067 weight_val = weight_val.abs();
2068 bias_val = bias_val.abs();
2069 }
2070
Eric Kunzee5e26762020-10-13 16:11:07 -07002071 Eigen::array<Eigen::Index, 4> reshape_dim;
2072 reshape_dim.fill(1);
2073 reshape_dim[3] = b_out_channels;
2074
2075 Eigen::array<Eigen::Index, 4> bcast;
2076 bcast[0] = out_batch;
2077 bcast[1] = out_height;
2078 bcast[2] = out_width;
2079 bcast[3] = 1;
2080
2081 // initialize with bias
Tai Ly307392a2023-05-12 21:42:19 +00002082 this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
Eric Kunzee5e26762020-10-13 16:11:07 -07002083
2084 int out_x_origin, out_y_origin;
2085 int out_x, out_y;
2086
2087 // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
2088 for (int ob = 0; ob < out_batch; ob++)
2089 {
2090 for (int ih = 0; ih < in_height; ih++)
2091 {
2092 for (int iw = 0; iw < in_width; iw++)
2093 {
Jerry Gea793f462023-04-11 00:05:02 +00002094 out_x_origin = iw * stride_x + out_pad_left;
2095 out_y_origin = ih * stride_y + out_pad_top;
Eric Kunzee5e26762020-10-13 16:11:07 -07002096 for (int ic = 0; ic < in_channels; ic++)
2097 {
2098 for (int fh = 0; fh < f_height; fh++)
2099 {
2100 for (int fw = 0; fw < f_width; fw++)
2101 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002102 out_x = out_x_origin + fw;
2103 out_y = out_y_origin + fh;
Eric Kunzee5e26762020-10-13 16:11:07 -07002104 for (int oc = 0; oc < out_channels; oc++)
2105 {
2106 if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
2107 {
2108 this->output->getTensor()(ob, out_y, out_x, oc) +=
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002109 (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
2110 (AccEigenType)weight_val(oc, fh, fw, ic));
Eric Kunzee5e26762020-10-13 16:11:07 -07002111 }
2112 }
2113 }
2114 }
2115 }
2116 }
2117 }
2118 }
2119
Tai Lya4d748b2023-03-28 22:06:56 +00002120 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -07002121 {
James Ward8b390432022-08-12 20:48:56 +01002122 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
2123 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -07002124 }
2125
2126 return GraphNode::eval();
2127}
2128
// template explicit instantiation
// Each DEF_INSTANTIATE_* macro forces instantiation of one operator template
// for a supported dtype combination (macros defined in template_types.h).
// ArgMax: single type parameter [in_t].
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);

// AvgPool2d: two type parameters — presumably [in_t, acc_t] given the
// INT8->INT32 / INT16->INT32 pairings; confirm against the class template.
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);

DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);

DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);

// FFT2d: floating-point only.
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);

DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);

// MatMul: two type parameters — presumably [in_t, out_t] given the
// INT8->INT32 / INT16->INT48 pairings; confirm against the class template.
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);

DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);

DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);

DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);