blob: d9608b7bb293512cb1ec5111e8f4d46fac6f27f6 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Luke Hutton261b7b62023-01-10 14:50:31 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "tensor_ops.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000017#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070018#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
Kevin Cheng9fe17242021-11-10 01:04:39 +000025int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
26 std::vector<int32_t> input_shape,
27 std::vector<int32_t> output_shape,
28 std::string& msg)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000029{
TatWai Chong86c403b2022-06-06 20:46:01 -070030 if (attribute->pad().size() != 4)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000031 {
32 msg = "illegal size for attribute padding";
33 return 1;
34 }
35
36 if (attribute->kernel().size() != 2)
37 {
38 msg = "illegal size for attribute kernel";
39 return 1;
40 }
41
42 if (attribute->stride().size() != 2)
43 {
44 msg = "illegal size for attribute stride";
45 return 1;
46 }
47
TatWai Chong86c403b2022-06-06 20:46:01 -070048 for (int32_t i : attribute->pad())
Kevin Cheng7eb93d72021-10-09 01:26:08 +000049 {
50 if (i < 0)
51 {
52 msg = "At least one pad is smaller than zero";
53 return 1;
54 }
55 }
56
57 for (int32_t i : attribute->kernel())
58 {
59 if (i < 1)
60 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000061 msg = "At least one kernel dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000062 return 1;
63 }
64 }
65
66 for (int32_t i : attribute->stride())
67 {
68 if (i < 1)
69 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000070 msg = "At least one stride dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000071 return 1;
72 }
73 }
74
75 int32_t IH = input_shape[1];
76 int32_t IW = input_shape[2];
77 int32_t OH = output_shape[1];
78 int32_t OW = output_shape[2];
79
TatWai Chong86c403b2022-06-06 20:46:01 -070080 int32_t pad_top = attribute->pad()[0];
81 int32_t pad_bottom = attribute->pad()[1];
82 int32_t pad_left = attribute->pad()[2];
83 int32_t pad_right = attribute->pad()[3];
Kevin Cheng7eb93d72021-10-09 01:26:08 +000084
85 int32_t stride_y = attribute->stride()[0];
86 int32_t stride_x = attribute->stride()[1];
87 int32_t kernel_y = attribute->kernel()[0];
88 int32_t kernel_x = attribute->kernel()[1];
89
90 if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
91 {
92 msg = "At least one pad is >= kernel dimension";
93 return 1;
94 }
95
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010096 int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
97 int32_t full_W = IW + pad_left + pad_right - kernel_x;
98
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 if ((full_H % stride_y != 0) || (full_W % stride_x != 0))
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100101 msg = "Parameters must yield exact integer output dimensions";
102 return 1;
103 }
104
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000105 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100106 {
107 msg = "Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000108 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000109 return 1;
110 }
111
112 return 0;
113}
114
// Validates a convolution attribute (pad/stride/dilation and zero points)
// against the given input/output/weight shapes for a 2D or 3D convolution.
// On failure writes a description into msg and returns 1; returns 0 when
// all parameters are consistent.
//
//   conv_dimension - 2 for CONV2D-style ops, 3 for CONV3D
//   offset_kernel  - index of the first spatial dim in the weights shape
//   InDtype/WeightDtype - used only for the zero-point legality checks
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         TOSA_REF_TYPE InDtype,
                         TOSA_REF_TYPE WeightDtype,
                         std::string& msg)
{
    // Arity checks: pad carries two entries per spatial dimension,
    // stride/dilation one entry per spatial dimension.
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    // Element-wise range checks: pads non-negative, stride/dilation >= 1.
    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    // For 3D convolutions every spatial index below is shifted by one to
    // skip the depth dimension; for 2D the depth quantities collapse to
    // neutral values (extent 1, stride 1, pad 0).
    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID       = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH       = input_shape[1 + offset_d];
    int32_t IW       = input_shape[2 + offset_d];
    int32_t OD       = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH       = output_shape[1 + offset_d];
    int32_t OW       = output_shape[2 + offset_d];

    int32_t stride_d   = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y   = attribute->stride()[0 + offset_d];
    int32_t stride_x   = attribute->stride()[1 + offset_d];
    int32_t kernel_d   = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h   = weights[offset_kernel + offset_d];
    int32_t kernel_w   = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    // pad carries two entries per spatial dim, so the depth offset doubles
    // when indexing into it.
    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

    // Effective span available for striding once the (dilated) kernel is
    // placed at the first position.
    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    // The spans must tile exactly by the strides...
    if ((full_H % stride_y != 0) || (full_W % stride_x != 0) || (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    // ...and the implied output extents must match the declared shape.
    if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1) || (OD != (full_D / stride_d) + 1))
    {
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" + msg_d +
              std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    // Zero points are only meaningful for int8 tensors.
    if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
    {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
    {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}
233
// Validates FFT2D/RFFT2D operand shapes. in_imag is empty for RFFT2D
// (which has no imaginary input). On failure writes a description into
// msg and returns 1; returns 0 on success.
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg)
{
    const bool is_rfft = in_imag.empty();

    // n is a power of two iff it is positive and has exactly one bit set.
    auto is_power_of_two = [](int32_t n) -> bool { return n > 0 && (n & (n - 1)) == 0; };

    // Compares b against a over a's rank (shapes are assumed same rank).
    auto shapes_differ = [](const std::vector<int32_t>& a, const std::vector<int32_t>& b) -> bool {
        for (size_t i = 0; i < a.size(); i++)
        {
            if (a[i] != b[i])
                return true;
        }
        return false;
    };

    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // RFFT does not have a second input
    if (!is_rfft && shapes_differ(in_real, in_imag))
    {
        msg = "Mismatch between real input shape and imaginary input shape";
        return 1;
    }

    if (shapes_differ(out_real, out_imag))
    {
        msg = "Mismatch between real output shape and imaginary output shape";
        return 1;
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (is_rfft)
    {
        // RFFT2D output keeps only the non-redundant half of the spectrum.
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    }
    else
    {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}
313
// OpArgMax constructor: registers the node as ARGMAX with exactly one input
// and one output, and parses the Axis attribute (the dimension to reduce).
template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    // Single-argument form — presumably a minimum rank of 1 (no scalar
    // inputs); confirm against GraphNode::setRequiredRank.
    setRequiredRank(1);

    INIT_ATTRIBUTE(Axis);
}
323
Tai Lya4d748b2023-03-28 22:06:56 +0000324template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700325OpArgMax<Rank, Dtype>::~OpArgMax()
326{
327 if (attribute)
328 delete attribute;
329}
330
Tai Lya4d748b2023-03-28 22:06:56 +0000331template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700332int OpArgMax<Rank, Dtype>::checkTensorAttributes()
333{
334 if (validateRequiredOperands())
335 return 1;
336
Kevin Chengcc61be32021-10-14 17:09:57 -0700337 if (validateRequiredRank(inputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700338 {
339 return 1;
340 }
341
Kevin Chengcc61be32021-10-14 17:09:57 -0700342 int32_t output_rank = inputs[0]->getRank() - 1;
343 if (output_rank != outputs[0]->getRank())
344 {
345 printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
346 return 1;
347 }
348
Tai Lya4d748b2023-03-28 22:06:56 +0000349 if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
Kevin Chengcc61be32021-10-14 17:09:57 -0700350 {
351 printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
352 return 1;
353 }
354
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
356 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
357
Kevin Chengcc61be32021-10-14 17:09:57 -0700358 if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
359 {
360 printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
361 return 1;
362 }
363
364 bool shape_check = true;
365 for (int32_t i = 0; i < input->getRank(); i++)
366 {
367 if (i < attribute->axis())
368 {
369 if (input->getShape()[i] != output->getShape()[i])
370 {
371 shape_check = false;
372 break;
373 }
374 }
375 else if (i > attribute->axis())
376 {
377 if (input->getShape()[i] != output->getShape()[i - 1])
378 {
379 shape_check = false;
380 break;
381 }
382 }
383 // No need to check i == axis
384 }
385 if (!shape_check)
386 {
387 printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
388 return 1;
389 }
390
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 return 0;
392}
393
Tai Lya4d748b2023-03-28 22:06:56 +0000394template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700395int OpArgMax<Rank, Dtype>::eval()
396{
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000397 // Check Tosa Level
398 auto tosa_level = g_func_config.tosa_level;
399 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
400
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
402
403 this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
404
405 return GraphNode::eval();
406}
407
// OpAvgPool2d constructor: registers the node as AVG_POOL2D with one input
// and one output, both required to be rank 4 (NHWC), and parses the Pool
// attribute (pad/kernel/stride and related fields).
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4, 4);

    INIT_ATTRIBUTE(Pool);
}
417
Tai Lya4d748b2023-03-28 22:06:56 +0000418template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100419OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700420{
421 if (attribute)
422 delete attribute;
423}
424
Tai Lya4d748b2023-03-28 22:06:56 +0000425template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100426int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700427{
428 if (validateRequiredOperands())
429 return 1;
430
431 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
432 {
433 return 1;
434 }
435
436 if (inputs[0]->matchType(*outputs[0]))
437 {
438 printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
439 return 1;
440 }
441
442 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
443 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
444
Tai Lya4d748b2023-03-28 22:06:56 +0000445 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
446 "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
447 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
448 "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
Eric Kunzee5e26762020-10-13 16:11:07 -0700449
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000450 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +0000451 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700452 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000453 msg = "OpAvgPool2d: " + msg;
454 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return 1;
456 }
457
458 return 0;
459}
460
Eric Kunze830add42022-01-25 22:56:46 -0800461// This calculates the number of padding elements used for each location along an axis
462// Average pooling only divides by the number of elements used, not including padding.
463// This function uses left/right, but is also used for vertical padding with top/bottom
Tai Lya4d748b2023-03-28 22:06:56 +0000464template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
465ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
466 int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
Eric Kunzee5e26762020-10-13 16:11:07 -0700467{
468 ETensor1<int32_t> result(out_size);
469
Eric Kunzee5e26762020-10-13 16:11:07 -0700470 result.setConstant(kernel_size);
471
Eric Kunze830add42022-01-25 22:56:46 -0800472 // adjust divisors on the left side for padding
473 // We start at the leftmost output element, and remove pad_left - (index * stride) elements
474 // until we have no more padding being used
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000475 for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++)
476 {
Eric Kunze830add42022-01-25 22:56:46 -0800477 int32_t adjust = pad_left - (index * stride);
478 result(index) -= adjust;
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 }
480
Eric Kunze830add42022-01-25 22:56:46 -0800481 // The process repeats on the right side. Padding starts taking effect as we
482 // near the rightmost input element. The first output element which touches
483 // padding is defined in the initialization of index below. Then we keep moving
484 // to the right, increasing padding until we get to the last output element.
485 int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000486 for (; index < out_size; index++)
487 {
Eric Kunze830add42022-01-25 22:56:46 -0800488 int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
489 result(index) -= adjust;
Eric Kunzee5e26762020-10-13 16:11:07 -0700490 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700491 return result;
492}
493
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
//
// Evaluates AVG_POOL2D: im2col the (padded) input, sum each pooling window,
// then divide by the per-position element count. Integer types use the
// quantized reciprocal-scale path with zero-point adjustment and clamping;
// float types divide directly.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    // Batch and channel dimensions pass through pooling unchanged.
    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];
    int kernel_y   = this->attribute->kernel()[0];
    int kernel_x   = this->attribute->kernel()[1];
    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    // im2col layout: one column per output element, one row per kernel tap.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Zero-pad only the spatial (H, W) dimensions.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        // Remove the input zero point before accumulation.
        input_val = input_val - (InEigenType)attribute->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of input_padded
        input_padded = input_padded.abs();
    }

    // assuming input and output have same scales
    // so input and output scaling is not required
    // TODO: check if this assumption TOSA made

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool: accumulate every kernel tap for each output element.
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_y * kernel_x; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide with div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_y, stride_y, pad_top, pad_bottom);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_x, stride_x, pad_left, pad_right);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    // Outer product (contract over the singleton dim) of the two 1D maps
    // yields the 2D per-position divisor, broadcast over batch and channel.
    ETensor2<int32_t> dm2_w   = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
    ETensor2<int32_t> dm2_h   = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
    ETensor4<int32_t> div_map = dm2_h.contract(dm2_w, contract_dims)
                                    .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
                                    .broadcast(bcast);
    if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
        Dtype != TOSA_REF_TYPE_FP64)
    {
        // Integer path: divide via fixed-point reciprocal scaling, then
        // re-apply the output zero point and clamp to the output range.
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        // Case for float-types
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}
638
// OpConv2d constructor: registers the node as CONV2D with three inputs
// (input, weight, bias) and one output; input/weight/output are required to
// be rank 4. Parses the Conv attribute (pad/stride/dilation/zero points).
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4, 4);

    INIT_ATTRIBUTE(Conv);
}
648
Tai Lya4d748b2023-03-28 22:06:56 +0000649template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000650OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700651{
652 if (attribute)
653 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700654}
655
Tai Lya4d748b2023-03-28 22:06:56 +0000656template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000657int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700658{
659 if (validateRequiredOperands())
660 return 1;
661
662 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
663 {
664 return 1;
665 }
666
667 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
668 if (inputs[2]->getRank() != 1)
669 {
670 printNodeValidationError("OpConv2d: bias tensor must be rank 1");
671 }
672
James Wardd34b3fc2023-01-18 14:51:25 +0000673 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000674 "OpConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700675
Eric Kunzee5e26762020-10-13 16:11:07 -0700676 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
677 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
678 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100679 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700680
Kevin Cheng9fe17242021-11-10 01:04:39 +0000681 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000682 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000683 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000685 msg = "OpConv2d: " + msg;
686 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700687 return 1;
688 }
689
Eric Kunzee5e26762020-10-13 16:11:07 -0700690 return 0;
691}
692
Tai Lya4d748b2023-03-28 22:06:56 +0000693template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000694int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700695{
696 int in_batch = this->input->getShape()[0];
697 int in_height = this->input->getShape()[1];
698 int in_width = this->input->getShape()[2];
699 int in_channels = this->input->getShape()[3];
700
701 int f_out_channels = this->weight->getShape()[0];
702 int f_height = this->weight->getShape()[1];
703 int f_width = this->weight->getShape()[2];
704 int f_in_channels = this->weight->getShape()[3];
705
706 int b_out_channels = this->bias->getShape()[0];
707
708 int out_batch = this->output->getShape()[0];
709 int out_height = this->output->getShape()[1];
710 int out_width = this->output->getShape()[2];
711 int out_channels = this->output->getShape()[3];
712
Kevin Chengacb550f2021-06-29 15:32:19 -0700713 ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
714 ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
715 in_channels);
716 ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
717 out_channels);
Tai Lya641dd52023-08-11 19:58:50 +0000718 ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d",
719 b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700720
TatWai Chong86c403b2022-06-06 20:46:01 -0700721 int pad_top = this->attribute->pad()[0];
722 int pad_bottom = this->attribute->pad()[1];
723 int pad_left = this->attribute->pad()[2];
724 int pad_right = this->attribute->pad()[3];
725
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000726 int stride_y = this->attribute->stride()[0];
727 int stride_x = this->attribute->stride()[1];
728 int dilation_y = this->attribute->dilation()[0];
729 int dilation_x = this->attribute->dilation()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000730
731 // Check Tosa Level
732 auto tosa_level = g_func_config.tosa_level;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000733 LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
734 "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
735 LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
736 "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +0000737 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
738 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
739 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
740 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
741 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
742 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
Eric Kunzee5e26762020-10-13 16:11:07 -0700743
744 DEBUG_INFO(OP,
745 "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +0000746 "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
Eric Kunzee5e26762020-10-13 16:11:07 -0700747 in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000748 out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
749 pad_left, pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -0700750
751 // GEMM-conv2d, left matrix is input, right matrix is weight
752 Eigen::array<Eigen::Index, 2> im2col_input_dims;
753 im2col_input_dims[0] = out_batch * out_height * out_width;
754 im2col_input_dims[1] = f_height * f_width * f_in_channels;
755
756 Eigen::array<Eigen::Index, 2> im2col_weight_dims;
757 im2col_weight_dims[0] = f_height * f_width * f_in_channels;
758 im2col_weight_dims[1] = f_out_channels;
759
760 Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
761 bias_reshaped_dims[0] = 1;
762 bias_reshaped_dims[1] = b_out_channels;
763
764 Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
765 weight_zp_bcast_dims[0] = f_height;
766 weight_zp_bcast_dims[1] = f_width;
767 weight_zp_bcast_dims[2] = f_in_channels;
768
769 Eigen::array<Eigen::Index, 2> bias_bcast_dims;
770 bias_bcast_dims[0] = out_batch * out_height * out_width;
Tai Lya641dd52023-08-11 19:58:50 +0000771 bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700772
773 Eigen::array<Eigen::Index, 4> col2im_output_dims;
774 col2im_output_dims[0] = out_batch;
775 col2im_output_dims[1] = out_height;
776 col2im_output_dims[2] = out_width;
777 col2im_output_dims[3] = out_channels;
778
779 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
780
TatWai Chong86c403b2022-06-06 20:46:01 -0700781 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
782 pad[0] = std::make_pair(0, 0);
783 pad[1] = std::make_pair(pad_top, pad_bottom);
784 pad[2] = std::make_pair(pad_left, pad_right);
785 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700786
787 TIn input_val = this->input->getTensor();
788 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000789 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700790 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000791 input_val = input_val - (InEigenType)attribute->input_zp();
792 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700793 }
794
TatWai Chong86c403b2022-06-06 20:46:01 -0700795 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700796
Tai Ly307392a2023-05-12 21:42:19 +0000797 TBias bias_val = this->bias->getTensor();
798
799 if (g_func_config.abs_mode)
800 {
801 // in abs_mode: take abs values of conv operands
802 input_padded = input_padded.abs();
803 weight_val = weight_val.abs();
804 bias_val = bias_val.abs();
805 }
806
Eric Kunzee5e26762020-10-13 16:11:07 -0700807 // extract_image_patches() output [N, KH, KW, H * W, C]
808 // need to transpose to [N, H * W, KH, KW, C]
809 ETensor5<InEigenType> input_extract_patches =
810 input_padded
Jerry Gea793f462023-04-11 00:05:02 +0000811 .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700812 .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
813
814 // reshape input to [N * H * W, KH * KW * C]
815 ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
816
817 // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
818 ETensor2<WeightEigenType> im2col_weight =
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000819 weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
Eric Kunzee5e26762020-10-13 16:11:07 -0700820
821 // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
822 // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
Tai Ly307392a2023-05-12 21:42:19 +0000823 ETensor2<OutEigenType> bias_2d =
824 (bias_val.reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700825
826 // output matrix is [N * H * W, C]
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000827 ETensor2<OutEigenType> contracted_result = (im2col_input.template cast<AccEigenType>().contract(
828 im2col_weight.template cast<AccEigenType>(), contract_dims))
829 .template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700830
831 // adding bias
James Ward8b390432022-08-12 20:48:56 +0100832 ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700833
834 // reshape back to [N, H, W, C]
835 this->output->getTensor() = biased_output.reshape(col2im_output_dims);
836
Tai Lya4d748b2023-03-28 22:06:56 +0000837 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -0700838 {
James Ward8b390432022-08-12 20:48:56 +0100839 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
840 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700841 }
842
843 return GraphNode::eval();
844}
845
// Constructor: CONV3D takes 3 operands (input, weight, bias), produces 1
// output, and operates on rank-5 tensors. The Conv attribute (pad/stride/
// dilation/zero points) is parsed via INIT_ATTRIBUTE.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(5, 5);

    INIT_ATTRIBUTE(Conv);
}
855
Tai Lya4d748b2023-03-28 22:06:56 +0000856template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000857OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
Kevin Cheng1533b852021-09-01 12:51:58 -0700858{
859 if (attribute)
860 delete attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700861}
862
Tai Lya4d748b2023-03-28 22:06:56 +0000863template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000864int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Kevin Cheng1533b852021-09-01 12:51:58 -0700865{
866 if (validateRequiredOperands())
867 return 1;
868
869 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
870 {
871 return 1;
872 }
873
874 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
875 if (inputs[2]->getRank() != 1)
876 {
877 printNodeValidationError("OpConv3d: bias tensor must be rank 1");
878 }
879
James Wardd34b3fc2023-01-18 14:51:25 +0000880 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000881 "OpConv3d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700882
Kevin Cheng1533b852021-09-01 12:51:58 -0700883 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
884 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
885 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100886 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng1533b852021-09-01 12:51:58 -0700887
Kevin Cheng9fe17242021-11-10 01:04:39 +0000888 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000889 if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000890 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Kevin Cheng1533b852021-09-01 12:51:58 -0700891 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000892 msg = "OpConv3d: " + msg;
893 printNodeValidationError(msg.c_str());
Kevin Cheng1533b852021-09-01 12:51:58 -0700894 return 1;
895 }
896
Kevin Cheng1533b852021-09-01 12:51:58 -0700897 return 0;
898}
899
// Evaluate CONV3D by direct convolution: pad the input, initialize the output
// with the (broadcast) bias, then accumulate input*weight products per output
// element in the accumulator type. Returns GraphNode::eval() on success.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Input is NDHWC.
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    // Weight is [OC, KD, KH, KW, IC].
    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    // Shape consistency; a size-1 bias is allowed and broadcast across OC.
    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d",
             b_out_channels, out_channels);

    // Attribute pad order is [d0, d1, top, bottom, left, right].
    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL,
                "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    // Padding pairs per dimension; batch and channel are never padded.
    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    // Zero points only apply to int8 operands per the TOSA spec.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // 1. initialize with bias
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    // Replicate across channels only when the bias is scalar (size 1).
    bcast[4] = (b_out_channels == 1) ? out_channels : 1;
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        // Multiply-accumulate in the wider accumulator type.
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    // INT48 accumulators must be clamped to the 48-bit range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1060
// Constructor: DEPTHWISE_CONV2D takes 3 operands (input, weight, bias),
// produces 1 output, and operates on rank-4 tensors. The shared Conv
// attribute (pad/stride/dilation/zero points) is parsed via INIT_ATTRIBUTE.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4, 4);

    INIT_ATTRIBUTE(Conv);
}
1072
Tai Lya4d748b2023-03-28 22:06:56 +00001073template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001074OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001075{
1076 if (attribute)
1077 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001078}
1079
Tai Lya4d748b2023-03-28 22:06:56 +00001080template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001081int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001082{
1083 if (validateRequiredOperands())
1084 return 1;
1085
1086 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1087 {
1088 return 1;
1089 }
1090
1091 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
1092 if (inputs[2]->getRank() != 1)
1093 {
1094 printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
1095 }
1096
James Wardd34b3fc2023-01-18 14:51:25 +00001097 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001098 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001099
Eric Kunzee5e26762020-10-13 16:11:07 -07001100 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1101 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1102 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001103 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001104
Kevin Cheng9fe17242021-11-10 01:04:39 +00001105 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001106 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001107 weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001108 {
Kevin Cheng9fe17242021-11-10 01:04:39 +00001109 msg = "OpDepthwiseConv2d: " + msg;
1110 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001111 return 1;
1112 }
1113
Eric Kunzee5e26762020-10-13 16:11:07 -07001114 return 0;
1115}
1116
// Evaluate DEPTHWISE_CONV2D: extract image patches for stride/dilation/pad
// handling, initialize the output with the (broadcast) bias, then perform a
// direct per-channel convolution where each input channel maps to
// 'channel_multiplier' output channels. Returns GraphNode::eval() on success.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Input is NHWC.
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    // Weight is [KH, KW, IC, channel_multiplier].
    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // Shape consistency; OC must equal IC * multiplier, and a size-1 bias is
    // allowed and broadcast across OC.
    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
             "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);

    // Attribute pad order is [top, bottom, left, right].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
               pad_left, pad_right);

    // Padding pairs per dimension; batch and channel are never padded.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    // Zero points only apply to int8 operands per the TOSA spec.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    // Replicate across channels only when the bias is scalar (size 1).
    bcast[3] = (b_out_channels == 1) ? out_channels : 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType
                                // NOTE(review): the OH*OW patch axis is indexed as
                                // ow * out_height + oh, i.e. column-major order —
                                // presumably matching Eigen's default layout here.
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh,
                                                                                      ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    // INT48 accumulators must be clamped to the 48-bit range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1258
// Constructor: FULLY_CONNECTED takes 3 operands (input, weight, bias),
// produces 1 output, and operates on rank-2 tensors. The FullyConnected
// attribute (zero points) is parsed via INIT_ATTRIBUTE.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                                   TosaAttributeBase* attribute_,
                                                                   uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(2, 2);

    INIT_ATTRIBUTE(FullyConnected);
}
1270
Tai Lya4d748b2023-03-28 22:06:56 +00001271template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001272OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001273{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001274 if (attribute)
1275 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001276}
1277
// Validate operand/result tensors and zero-point attributes for
// FULLY_CONNECTED. Returns 0 when the node is well-formed, 1 (after printing
// a validation error) otherwise. The dynamic_casts must happen before the
// shape checks that dereference them.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    // Input is [N, IC] and weight is [OC, IC]: the inner dimensions must agree.
    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    // Bias is per output channel, so it must match weight's OC dimension.
    if (weight->getShape()[0] != bias->getShape()[0])
    {
        printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Zero points are only meaningful for int8 operands per the TOSA spec.
    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
             "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
1317
// Evaluate FULLY_CONNECTED: output = input x weight^T + bias, with optional
// zero-point subtraction (int8) and int48 accumulator clamping.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    // Contract dim 1 of the (shuffled) input against dim 0 of the weight.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    // Weight is stored [OC, IC]; transpose to [IC, OC] for the contraction.
    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    int b_out_channels = this->bias->getShape()[0];
    int out_channels   = this->output->getShape()[1];

    // Bias is either per-channel ([OC]) or a broadcast scalar ([1]).
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d",
             b_out_channels, out_channels);

    // Reshape bias to [1, b_out_channels] so it can broadcast over the batch.
    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = b_out_channels;

    // Broadcast over batch, and over channels too when bias is a single value.
    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        // Remove zero-points before the integer multiply-accumulate.
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_val  = input_val.abs();
        weight_val = weight_val.abs();
        bias_val   = bias_val.abs();
    }

    // Accumulate in AccEigenType, then narrow to the output type before adding bias.
    this->output->getTensor() = input_val.template cast<AccEigenType>()
                                    .contract(weight_val.template cast<AccEigenType>(), dims)
                                    .template cast<OutEigenType>() +
                                bias_val.reshape(bias_reshape).broadcast(bias_bcast);

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate the result to the 48-bit accumulator range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}
1370
// Construct a MATMUL node: 2 inputs (a, b), 1 output, rank-3 tensors only,
// and a copy of the serialized MatMul attribute (zero-points).
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    setRequiredOperands(2, 1);
    // MATMUL operates on batched rank-3 tensors only.
    setRequiredRank(3, 3);

    // Copies attribute_ into this->attribute; released in the destructor.
    INIT_ATTRIBUTE(MatMul);
}
1380
Tai Lya4d748b2023-03-28 22:06:56 +00001381template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001382OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001383{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001384 if (attribute)
1385 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001386}
1387
Tai Lya4d748b2023-03-28 22:06:56 +00001388template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001389int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001390{
1391 if (validateRequiredOperands())
1392 return 1;
1393
1394 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1395 {
1396 return 1;
1397 }
1398
James Wardd34b3fc2023-01-18 14:51:25 +00001399 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001400 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001401
Kevin Cheng2d60f002021-06-09 14:18:32 -07001402 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1403 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001404 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001405
Kevin Cheng2d60f002021-06-09 14:18:32 -07001406 ASSERT_MEM(a && b && output);
1407
1408 // a: [N, H, C]
1409 // b: [N, C, W]
1410 // c: [N, H, W]
1411
1412 // Check N
1413 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001414 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001415 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001416 return 1;
1417 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001418 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001419
Kevin Cheng2d60f002021-06-09 14:18:32 -07001420 // Check C
1421 if (a->getShape()[2] != b->getShape()[1])
1422 {
1423 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1424 return 1;
1425 }
1426 C = a->getShape()[2];
1427
1428 // Check H
1429 if (a->getShape()[1] != output->getShape()[1])
1430 {
1431 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1432 return 1;
1433 }
1434 H = a->getShape()[1];
1435
1436 // Check W
1437 if (b->getShape()[2] != output->getShape()[2])
1438 {
1439 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1440 return 1;
1441 }
1442 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001443
Tai Lya4d748b2023-03-28 22:06:56 +00001444 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
1445 "OpMatMul: A zeropoint must be zero for non int8_t data");
1446 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
1447 "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001448
Eric Kunzee5e26762020-10-13 16:11:07 -07001449 return 0;
1450}
1451
// Evaluate MATMUL: for each batch n, output[n] = a[n] x b[n], computed as a
// rank-2 Eigen contraction, with optional int8 zero-point subtraction and
// int48 accumulator clamping.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    // Contract dim 1 of a[n] ([H, C]) against dim 0 of b[n] ([C, W]).
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        // Remove zero-points before the integer multiply-accumulate.
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of matmul operands
        a_val = a_val.abs();
        b_val = b_val.abs();
    }

    // Per-batch rank-2 views and the rank-3 shape used to rebuild the output.
    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        // Slice out batch i from each operand.
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        // Accumulate in AccEigenType, then narrow to the output element type.
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        // Build the output by concatenating the per-batch results along dim 0.
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        // Saturate the result to the 48-bit accumulator range.
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1513
// Construct a MAX_POOL2D node: 1 input, 1 output, rank-4 (NHWC) tensors only,
// and a copy of the serialized Pool attribute (kernel/stride/pad).
template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    // MAX_POOL2D operates on rank-4 tensors only.
    setRequiredRank(4, 4);

    // Copies attribute_ into this->attribute; released in the destructor.
    INIT_ATTRIBUTE(Pool);
}
1523
Tai Lya4d748b2023-03-28 22:06:56 +00001524template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001525OpMaxPool2d<Dtype>::~OpMaxPool2d()
1526{
1527 if (attribute)
1528 delete attribute;
1529}
1530
// Validate operand counts, ranks, type agreement and pool attribute legality
// for MAX_POOL2D. Returns 0 on success, 1 on validation failure.
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // Max pooling never changes the element type: input and output must match.
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Shared helper validates kernel/stride/pad sizes and values against the
    // input/output shapes; on failure it fills in a diagnostic message.
    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpMaxPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
1561
// Evaluate MAX_POOL2D via an im2col-style transform: pad the input with the
// type's lowest value, extract each KH x KW patch, then take the per-patch
// maximum (via argmax) and reshape back to NHWC.
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    // Pooling preserves batch and channel counts.
    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // pad() layout is [top, bottom, left, right].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    // im2col matrix: one column per output element, one row per kernel tap.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    // Final NHWC output shape.
    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Spatial-only padding: batch and channel dims are untouched.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KHxHW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1654
// Construct an FFT2D node: 2 inputs (real, imaginary), 2 outputs (real,
// imaginary), rank-3 tensors only, plus the FFT attribute (inverse flag).
template <TOSA_REF_TYPE Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    setRequiredOperands(2, 2);
    // FFT2D operates on rank-3 ([N, H, W]) tensors only.
    setRequiredRank(3, 3);

    // Copies attribute_ into this->attribute; released in the destructor.
    INIT_ATTRIBUTE(FFT);
}
1664
Tai Lya4d748b2023-03-28 22:06:56 +00001665template <TOSA_REF_TYPE Dtype>
1666OpFFT2d<Dtype>::~OpFFT2d()
1667{
Luke Hutton57287132023-02-06 14:54:18 +00001668 if (attribute)
1669 delete attribute;
1670}
1671
// Validate operand counts, ranks, type agreement and FFT shape constraints
// for FFT2D. Returns 0 on success, 1 on validation failure.
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]) ||
        validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    // Real/imaginary input and output tensors must all share one element type.
    if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || inputs[0]->matchType(*inputs[1]))
    {
        printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in_real  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    in_imag  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in_real && in_imag && out_real && out_imag);

    // Shared helper validates the FFT shape relationships between the four
    // tensors; on failure it fills in a diagnostic message.
    std::string msg;
    if (check_fft_shape(in_real->getShape(), in_imag->getShape(), out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
1707
// Evaluate FFT2D as a direct (O(H*W) per output point) 2-D DFT over the
// complex input; the attribute's inverse flag flips the twiddle-angle sign.
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
    int in_real_batch  = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width  = this->in_real->getShape()[2];

    int in_imag_batch  = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width  = this->in_imag->getShape()[2];

    int out_real_batch  = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width  = this->out_real->getShape()[2];

    int out_imag_batch  = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width  = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP, "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);

    // sign_val selects forward (+1) or inverse (-1) transform direction.
    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

    TIn in_real_val = this->in_real->getTensor();
    TIn in_imag_val = this->in_imag->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of real and imag operands
        in_real_val = in_real_val.abs();
        in_imag_val = in_imag_val.abs();
    }

    // Direct DFT: for each output bin (oy, ox), accumulate over every input
    // sample (iy, ix) weighted by cos/sin of the twiddle angle.
    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = in_real_val(n, iy, ix);
                        OutEigenType val_imag = in_imag_val(n, iy, ix);
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI *
                            ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
                        sum_real += val_real * cos(a) + val_imag * sin(a);
                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1782
// Construct an RFFT2D node: 1 real input, 2 outputs (real, imaginary),
// rank-3 tensors only. RFFT2D has no serialized attribute to copy.
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    setRequiredOperands(1, 2);
    // RFFT2D operates on rank-3 ([N, H, W]) tensors only.
    setRequiredRank(3, 3);
}
1790
Tai Lya4d748b2023-03-28 22:06:56 +00001791template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001792OpRFFT2d<Dtype>::~OpRFFT2d()
1793{}
Luke Hutton261b7b62023-01-10 14:50:31 +00001794
Tai Lya4d748b2023-03-28 22:06:56 +00001795template <TOSA_REF_TYPE Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +00001796int OpRFFT2d<Dtype>::checkTensorAttributes()
1797{
1798 if (validateRequiredOperands())
1799 return 1;
1800
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001801 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
Luke Hutton261b7b62023-01-10 14:50:31 +00001802 {
1803 return 1;
1804 }
1805
1806 if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
1807 {
1808 printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
1809 return 1;
1810 }
1811
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001812 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
Luke Hutton261b7b62023-01-10 14:50:31 +00001813 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1814 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1815
1816 ASSERT_MEM(in && out_real && out_imag);
1817
Luke Hutton57287132023-02-06 14:54:18 +00001818 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001819 if (check_fft_shape(in->getShape(), {}, out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton261b7b62023-01-10 14:50:31 +00001820 {
Luke Hutton57287132023-02-06 14:54:18 +00001821 msg = "OpRFFT2d: " + msg;
1822 printNodeValidationError(msg.c_str());
Luke Hutton261b7b62023-01-10 14:50:31 +00001823 return 1;
1824 }
1825
1826 return 0;
1827}
1828
// Evaluate RFFT2D as a direct 2-D DFT of a real-valued input; only the real
// and imaginary outputs' declared widths (the non-redundant bins) are filled.
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
    int32_t in_batch  = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width  = in->getShape()[2];

    int32_t out_real_batch  = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width  = out_real->getShape()[2];

    int32_t out_imag_batch  = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width  = out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width, out_real_batch, out_real_height, out_real_width, out_imag_batch,
               out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

    TIn in_val = this->in->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in operand
        in_val = in_val.abs();
    }

    // Direct DFT: for each output bin (oy, ox), accumulate over every input
    // sample (iy, ix) weighted by cos/-sin of the twiddle angle.
    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        sum_real += in_val(n, iy, ix) * cos(a);
                        sum_imag += -in_val(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1891
// Construct a TRANSPOSE_CONV2D node: 3 inputs (input, weight, bias),
// 1 output, rank-4 tensors only, plus the TransposeConv attribute
// (out_pad/stride/output_shape/zero-points).
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    // TRANSPOSE_CONV2D operates on rank-4 tensors only.
    setRequiredRank(4, 4);

    // Copies attribute_ into this->attribute; released in the destructor.
    INIT_ATTRIBUTE(TransposeConv);
}
1903
Tai Lya4d748b2023-03-28 22:06:56 +00001904template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001905OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001906{
1907 if (attribute)
1908 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001909}
1910
// Validates operand count/ranks, attribute shapes, the expected output size
// and the zero-point constraints for TRANSPOSE_CONV2D.
// Returns 0 on success, 1 on a validation failure (ERROR_IF also reports the
// failure message).
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpTransposeConv2d: Output data type not supported for this configuration of operator");

    // Bind the generic operand tensors to their typed views.
    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // out_pad holds [top, bottom, left, right].
    if (attribute->out_pad().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
        return 1;
    }

    // stride holds [y, x].
    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->output_shape().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
        return 1;
    }

    // Strides must be at least one in each dimension.
    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
            return 1;
        }
    }

    // The attribute's output_shape must match the actual output tensor shape.
    for (int d = 0; d < 4; d++)
    {
        if (attribute->output_shape()[d] != this->output->getShape()[d])
        {
            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
            return 1;
        }
    }

    // Input/output tensors are NHWC; weight is [OC, KH, KW, IC] (see eval()).
    int32_t IH = input->getShape()[1];
    int32_t IW = input->getShape()[2];
    int32_t OH = output->getShape()[1];
    int32_t OW = output->getShape()[2];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_h = weight->getShape()[1];
    int32_t kernel_w = weight->getShape()[2];

    int32_t out_pad_top    = attribute->out_pad()[0];
    int32_t out_pad_bottom = attribute->out_pad()[1];
    int32_t out_pad_left   = attribute->out_pad()[2];
    int32_t out_pad_right  = attribute->out_pad()[3];

    // Negative out_pad values are permitted but must be strictly greater than
    // the negated kernel dimension; (i / 2) + 1 maps the pad index (top/bottom
    // -> height axis 1, left/right -> width axis 2) onto the weight shape.
    for (size_t i = 0; i < attribute->out_pad().size(); i++)
    {
        ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
                 "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
    }

    // Expected spatial output size for a transpose convolution.
    int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
    int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;

    if ((OH != H) || (OW != W))
    {
        std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
                          std::to_string(H) + "," + std::to_string(W) + ")";
        printNodeValidationError(msg.c_str());
        return 1;
    }

    // Zero points are only meaningful for INT8 operands; all other data types
    // must carry a zero point of 0.
    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
             "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
2005
// Evaluates TRANSPOSE_CONV2D: initializes the output with the (broadcast)
// bias, then scatter-accumulates each input pixel times the kernel into the
// output. Returns the result of GraphNode::eval() on success, non-zero on a
// failed ERROR_IF/LEVEL_CHECK.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Input/output layout is NHWC.
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    // Weight layout is [OC, KH, KW, IC] (indexed as weight(oc, fh, fw, ic) below).
    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    // out_pad is [top, bottom, left, right]; stride is [y, x].
    int out_pad_top    = this->attribute->out_pad()[0];
    int out_pad_bottom = this->attribute->out_pad()[1];
    int out_pad_left   = this->attribute->out_pad()[2];
    int out_pad_right  = this->attribute->out_pad()[3];

    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Cross-tensor consistency checks; bias may be per-channel or a single
    // broadcast value (b_out_channels == 1).
    ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
             f_out_channels, out_channels);
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
             "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL,
                "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, out_pad_top, out_pad_bottom, out_pad_left,
               out_pad_right);

    // For INT8 operands, subtract the zero points so the accumulation below
    // runs on zero-centred values.
    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_val  = input_val.abs();
        weight_val = weight_val.abs();
        bias_val   = bias_val.abs();
    }

    // Reshape bias to [1,1,1,b_out_channels] then broadcast it over the whole
    // output; when b_out_channels == 1 the single bias value is also broadcast
    // across the channel dimension.
    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = (b_out_channels == 1) ? out_channels : 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    int out_x_origin, out_y_origin;
    int out_x, out_y;

    // Scatter each input element through the kernel into the output,
    // accumulating in AccEigenType and skipping positions that fall outside
    // the output bounds.
    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int ih = 0; ih < in_height; ih++)
        {
            for (int iw = 0; iw < in_width; iw++)
            {
                out_x_origin = iw * stride_x + out_pad_left;
                out_y_origin = ih * stride_y + out_pad_top;
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int fh = 0; fh < f_height; fh++)
                    {
                        for (int fw = 0; fw < f_width; fw++)
                        {
                            out_x = out_x_origin + fw;
                            out_y = out_y_origin + fh;
                            for (int oc = 0; oc < out_channels; oc++)
                            {
                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                                {
                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
                                        (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
                                                       (AccEigenType)weight_val(oc, fh, fw, ic));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // Saturate INT48 results to the accumulator range [AccQMin, AccQMax].
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
2136
// template explicit instantiation
// [in_t]
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);

// [in_t, out_t]
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);

// [in_t]
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);

// [in_t, out_t]
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);

// [in_t]
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);

// [in_t]
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);