blob: 2d54d8eb9756cd276094df1ad2c0f14a38d59af7 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Luke Hutton261b7b62023-01-10 14:50:31 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "tensor_ops.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000017#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070018#include "quant_util.h"
19#include "template_types.h"
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
Kevin Cheng9fe17242021-11-10 01:04:39 +000025int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
26 std::vector<int32_t> input_shape,
27 std::vector<int32_t> output_shape,
28 std::string& msg)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000029{
TatWai Chong86c403b2022-06-06 20:46:01 -070030 if (attribute->pad().size() != 4)
Kevin Cheng7eb93d72021-10-09 01:26:08 +000031 {
32 msg = "illegal size for attribute padding";
33 return 1;
34 }
35
36 if (attribute->kernel().size() != 2)
37 {
38 msg = "illegal size for attribute kernel";
39 return 1;
40 }
41
42 if (attribute->stride().size() != 2)
43 {
44 msg = "illegal size for attribute stride";
45 return 1;
46 }
47
TatWai Chong86c403b2022-06-06 20:46:01 -070048 for (int32_t i : attribute->pad())
Kevin Cheng7eb93d72021-10-09 01:26:08 +000049 {
50 if (i < 0)
51 {
52 msg = "At least one pad is smaller than zero";
53 return 1;
54 }
55 }
56
57 for (int32_t i : attribute->kernel())
58 {
59 if (i < 1)
60 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000061 msg = "At least one kernel dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000062 return 1;
63 }
64 }
65
66 for (int32_t i : attribute->stride())
67 {
68 if (i < 1)
69 {
Kevin Cheng9fe17242021-11-10 01:04:39 +000070 msg = "At least one stride dimension is smaller than one";
Kevin Cheng7eb93d72021-10-09 01:26:08 +000071 return 1;
72 }
73 }
74
75 int32_t IH = input_shape[1];
76 int32_t IW = input_shape[2];
77 int32_t OH = output_shape[1];
78 int32_t OW = output_shape[2];
79
TatWai Chong86c403b2022-06-06 20:46:01 -070080 int32_t pad_top = attribute->pad()[0];
81 int32_t pad_bottom = attribute->pad()[1];
82 int32_t pad_left = attribute->pad()[2];
83 int32_t pad_right = attribute->pad()[3];
Kevin Cheng7eb93d72021-10-09 01:26:08 +000084
85 int32_t stride_y = attribute->stride()[0];
86 int32_t stride_x = attribute->stride()[1];
87 int32_t kernel_y = attribute->kernel()[0];
88 int32_t kernel_x = attribute->kernel()[1];
89
90 if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
91 {
92 msg = "At least one pad is >= kernel dimension";
93 return 1;
94 }
95
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010096 int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
97 int32_t full_W = IW + pad_left + pad_right - kernel_x;
98
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 if ((full_H % stride_y != 0) || (full_W % stride_x != 0))
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000100 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100101 msg = "Parameters must yield exact integer output dimensions";
102 return 1;
103 }
104
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000105 if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1))
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100106 {
107 msg = "Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000108 std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
Kevin Cheng7eb93d72021-10-09 01:26:08 +0000109 return 1;
110 }
111
112 return 0;
113}
114
// Validates a convolution attribute (pad/stride/dilation plus zero points)
// against the provided input/output/weight shapes, for either 2D or 3D
// convolutions (conv_dimension selects which).
//
// Parameters:
//   attribute      - the deserialized conv attribute to validate
//   conv_dimension - 2 for CONV2D/DEPTHWISE, 3 for CONV3D
//   input_shape    - activation shape; layout assumed N[D]HWC based on the
//                    indexing below
//   output_shape   - expected result shape, same layout
//   weights        - weight tensor shape; the kernel spatial dims start at
//                    index 'offset_kernel' (callers pass 0 or 1 depending on
//                    weight layout)
//   InDtype/WeightDtype - used only for the zero-point legality checks
//   msg            - on failure, receives a human-readable reason
// Returns 0 on success, 1 on any validation failure.
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         TOSA_REF_TYPE InDtype,
                         TOSA_REF_TYPE WeightDtype,
                         std::string& msg)
{
    // pad holds 2 entries per spatial dimension; stride/dilation hold 1 each.
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    // Element-wise range checks: pads non-negative, stride/dilation >= 1.
    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    // offset_d shifts the H/W indices by one when a depth (D) dimension is
    // present (3D conv). The depth components default to neutral values (1 or
    // 0) in the 2D case so the shared formulas below still hold.
    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID       = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH       = input_shape[1 + offset_d];
    int32_t IW       = input_shape[2 + offset_d];
    int32_t OD       = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH       = output_shape[1 + offset_d];
    int32_t OW       = output_shape[2 + offset_d];

    int32_t stride_d   = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y   = attribute->stride()[0 + offset_d];
    int32_t stride_x   = attribute->stride()[1 + offset_d];
    int32_t kernel_d   = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h   = weights[offset_kernel + offset_d];
    int32_t kernel_w   = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    // pad has two entries per dimension, so the skip for the depth pads is
    // twice the single-entry skip used above.
    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

    // Padded extent minus the dilated kernel footprint; must divide evenly by
    // the stride for exact integer output dimensions.
    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    if ((full_H % stride_y != 0) || (full_W % stride_x != 0) || (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) || (OW != (full_W / stride_x) + 1) || (OD != (full_D / stride_d) + 1))
    {
        // Only include the depth component in the message for 3D convs.
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" + msg_d +
              std::to_string((full_H / stride_y) + 1) + "," + std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    // Zero points are only meaningful for int8 data; anything else must be 0.
    if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
    {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
    {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}
233
// Validates the shapes of an FFT2D/RFFT2D operator. All shapes are rank-3
// [batch, height, width]. An empty 'in_imag' marks the RFFT case (real input
// only). On failure, writes the reason into 'msg' and returns 1; otherwise
// returns 0.
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg)
{
    // RFFT variants supply no imaginary input tensor.
    const bool rfft = in_imag.empty();
    const auto pow2 = [](int32_t v) -> bool { return v > 0 && (v & (v - 1)) == 0; };

    if (!pow2(in_real[1]) || !pow2(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // For a full FFT the real and imaginary inputs must agree exactly.
    if (!rfft && in_real != in_imag)
    {
        msg = "Mismatch between real input shape and imaginary input shape";
        return 1;
    }

    // Real and imaginary outputs must always agree.
    if (out_real != out_imag)
    {
        msg = "Mismatch between real output shape and imaginary output shape";
        return 1;
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (rfft)
    {
        // RFFT emits only the non-redundant half of the spectrum.
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    }
    else
    {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}
313
// ARGMAX operator node: constructor registers operand counts and rank
// constraints, then materializes the axis attribute.
template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    // One input tensor, one output tensor.
    setRequiredOperands(1, 1);
    // Single-argument form: minimum rank 1 (presumably no upper bound here,
    // unlike the (min, max) form used by the pooling/conv ops — confirm
    // against setRequiredRank's definition).
    setRequiredRank(1);

    // Populates the 'attribute' member from attribute_ (project macro,
    // defined elsewhere).
    INIT_ATTRIBUTE(Axis);
}
323
Tai Lya4d748b2023-03-28 22:06:56 +0000324template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700325OpArgMax<Rank, Dtype>::~OpArgMax()
326{
327 if (attribute)
328 delete attribute;
329}
330
Tai Lya4d748b2023-03-28 22:06:56 +0000331template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700332int OpArgMax<Rank, Dtype>::checkTensorAttributes()
333{
334 if (validateRequiredOperands())
335 return 1;
336
Kevin Chengcc61be32021-10-14 17:09:57 -0700337 if (validateRequiredRank(inputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700338 {
339 return 1;
340 }
341
Kevin Chengcc61be32021-10-14 17:09:57 -0700342 int32_t output_rank = inputs[0]->getRank() - 1;
343 if (output_rank != outputs[0]->getRank())
344 {
345 printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
346 return 1;
347 }
348
Tai Lya4d748b2023-03-28 22:06:56 +0000349 if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
Kevin Chengcc61be32021-10-14 17:09:57 -0700350 {
351 printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
352 return 1;
353 }
354
Eric Kunzee5e26762020-10-13 16:11:07 -0700355 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
356 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
357
Kevin Chengcc61be32021-10-14 17:09:57 -0700358 if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
359 {
360 printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
361 return 1;
362 }
363
364 bool shape_check = true;
365 for (int32_t i = 0; i < input->getRank(); i++)
366 {
367 if (i < attribute->axis())
368 {
369 if (input->getShape()[i] != output->getShape()[i])
370 {
371 shape_check = false;
372 break;
373 }
374 }
375 else if (i > attribute->axis())
376 {
377 if (input->getShape()[i] != output->getShape()[i - 1])
378 {
379 shape_check = false;
380 break;
381 }
382 }
383 // No need to check i == axis
384 }
385 if (!shape_check)
386 {
387 printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
388 return 1;
389 }
390
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 return 0;
392}
393
// Computes ARGMAX along the attribute's axis and writes INT32 indices into
// the output tensor.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    // Eigen's argmax reduces the chosen axis, producing a rank (Rank - 1)
    // tensor of DenseIndex positions.
    Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());

    // Convert each DenseIndex to the output element type (INT32).
    this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });

    return GraphNode::eval();
}
407
// AVG_POOL2D operator node: constructor registers operand counts and rank
// constraints, then materializes the pooling attribute.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    // One input tensor, one output tensor.
    setRequiredOperands(1, 1);
    // NHWC only: rank must be exactly 4.
    setRequiredRank(4, 4);

    // Populates the 'attribute' member from attribute_ (project macro,
    // defined elsewhere).
    INIT_ATTRIBUTE(Pool);
}
417
Tai Lya4d748b2023-03-28 22:06:56 +0000418template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100419OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700420{
421 if (attribute)
422 delete attribute;
423}
424
// Validates operand ranks/types, zero-point legality, and the pooling
// attribute against the input/output shapes.
// Returns 0 on success, 1 on any validation failure.
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // NOTE(review): matchType() returning non-zero appears to indicate a type
    // mismatch here — confirm against its definition.
    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    // Zero points are only meaningful for int8 data; anything else must be 0.
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
             "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
             "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");

    // Delegate the pad/kernel/stride vs. shape consistency checks to the
    // shared pooling validator above.
    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}
460
Eric Kunzee5e26762020-10-13 16:11:07 -0700461// assuming input and output tensor have same scales like tflite reference
462// so no need to scale input and output
Tai Lya4d748b2023-03-28 22:06:56 +0000463template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
James Ward8b390432022-08-12 20:48:56 +0100464int OpAvgPool2d<Dtype, AccDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700465{
466 int in_batch = this->in->getShape()[0];
467 int in_height = this->in->getShape()[1];
468 int in_width = this->in->getShape()[2];
469 int in_channels = this->in->getShape()[3];
470
471 int out_batch = this->out->getShape()[0];
472 int out_height = this->out->getShape()[1];
473 int out_width = this->out->getShape()[2];
474 int out_channels = this->out->getShape()[3];
475
Kevin Chengacb550f2021-06-29 15:32:19 -0700476 ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
477 ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700478
TatWai Chong86c403b2022-06-06 20:46:01 -0700479 int pad_top = this->attribute->pad()[0];
480 int pad_bottom = this->attribute->pad()[1];
481 int pad_left = this->attribute->pad()[2];
482 int pad_right = this->attribute->pad()[3];
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000483 int kernel_y = this->attribute->kernel()[0];
484 int kernel_x = this->attribute->kernel()[1];
485 int stride_y = this->attribute->stride()[0];
486 int stride_x = this->attribute->stride()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000487
488 // Check Tosa Level
489 auto tosa_level = g_func_config.tosa_level;
490 LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
491 LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
492 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
493 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
494 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
495 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
496 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
497 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
Eric Kunzee5e26762020-10-13 16:11:07 -0700498
Tai Lya4d748b2023-03-28 22:06:56 +0000499 TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
James Ward8b390432022-08-12 20:48:56 +0100500
Eric Kunzee5e26762020-10-13 16:11:07 -0700501 DEBUG_INFO(OP,
502 "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
James Ward8b390432022-08-12 20:48:56 +0100503 "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
Jerry Gea793f462023-04-11 00:05:02 +0000504 in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
505 kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
TatWai Chong86c403b2022-06-06 20:46:01 -0700507 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
508 pad[0] = std::make_pair(0, 0);
509 pad[1] = std::make_pair(pad_top, pad_bottom);
510 pad[2] = std::make_pair(pad_left, pad_right);
511 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700512
513 ETensor4<InEigenType> input_val = this->in->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000514 if (Dtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700515 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000516 input_val = input_val - (InEigenType)attribute->input_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700517 }
518
Tai Ly307392a2023-05-12 21:42:19 +0000519 if (g_func_config.abs_mode)
520 {
James Ward29e46cf2023-10-23 11:47:25 +0000521 // in abs_mode: take abs values of input_val
522 input_val = input_val.abs();
Tai Ly307392a2023-05-12 21:42:19 +0000523 }
524
Eric Kunzee5e26762020-10-13 16:11:07 -0700525 // assuming input and output have same scales
526 // so input and output scaling is not required
527 // TODO: check if this assumption TOSA made
528
James Ward5a9e0cd2023-10-09 16:51:26 +0000529 ETensor4<OutEigenType> out_tens(out_batch, out_height, out_width, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700530
531 // sum pool
James Ward5a9e0cd2023-10-09 16:51:26 +0000532 for (int ob = 0; ob < out_batch; ++ob)
Eric Kunzee5e26762020-10-13 16:11:07 -0700533 {
James Ward5a9e0cd2023-10-09 16:51:26 +0000534 for (int oh = 0; oh < out_height; ++oh)
Eric Kunzee5e26762020-10-13 16:11:07 -0700535 {
James Ward5a9e0cd2023-10-09 16:51:26 +0000536 for (int ow = 0; ow < out_width; ++ow)
537 {
538 for (int oc = 0; oc < out_channels; ++oc)
539 {
540 AccEigenType acc(0);
541 int filter_count = 0;
542 const int iy = oh * stride_y - pad_top;
543 const int ix = ow * stride_x - pad_left;
544 for (int ky = 0; ky < kernel_y; ++ky)
545 {
546 for (int kx = 0; kx < kernel_x; ++kx)
547 {
548 const int y = iy + ky;
549 const int x = ix + kx;
550 if ((0 <= y && y < in_height) && (0 <= x && x < in_width))
551 {
552 ++filter_count;
James Ward29e46cf2023-10-23 11:47:25 +0000553 acc = acc + (AccEigenType)input_val(ob, y, x, oc);
James Ward5a9e0cd2023-10-09 16:51:26 +0000554 }
555 }
556 }
557 if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
558 Dtype != TOSA_REF_TYPE_FP64)
559 {
560 try
561 {
562 int32_t multiplier, shift;
563 OutEigenType out;
564 TosaReference::QuantUtil::reciprocal_scale(filter_count, multiplier, shift);
565
566 out = (OutEigenType)TosaReference::QuantUtil::apply_scale_32(acc, multiplier, shift, false);
567 out = out + (OutEigenType)(attribute->output_zp());
568 out = std::max(out, (OutEigenType)QMin);
569 out_tens(ob, oh, ow, oc) = std::min(out, (OutEigenType)QMax);
570 }
571 catch (std::string desc)
572 {
573 REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
574 }
575 }
576 else
577 {
578 REQUIRE(filter_count != 0, "OpAvgPool2d number of filters should be non-zero.");
579 out_tens(ob, oh, ow, oc) = acc / static_cast<OutEigenType>(filter_count);
580 }
581 }
582 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700583 }
584 }
James Ward5a9e0cd2023-10-09 16:51:26 +0000585 this->out->getTensor() = out_tens;
Eric Kunzee5e26762020-10-13 16:11:07 -0700586 return GraphNode::eval();
587}
588
// CONV2D operator node: constructor registers operand counts and rank
// constraints, then materializes the convolution attribute.
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    // Three inputs (activation, weight, bias), one output.
    setRequiredOperands(3, 1);
    // NHWC only: rank must be exactly 4 (bias rank is checked separately).
    setRequiredRank(4, 4);

    // Populates the 'attribute' member from attribute_ (project macro,
    // defined elsewhere).
    INIT_ATTRIBUTE(Conv);
}
598
Tai Lya4d748b2023-03-28 22:06:56 +0000599template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000600OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -0700601{
602 if (attribute)
603 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -0700604}
605
Tai Lya4d748b2023-03-28 22:06:56 +0000606template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000607int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -0700608{
609 if (validateRequiredOperands())
610 return 1;
611
612 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
613 {
614 return 1;
615 }
616
617 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
618 if (inputs[2]->getRank() != 1)
619 {
620 printNodeValidationError("OpConv2d: bias tensor must be rank 1");
621 }
622
James Wardd34b3fc2023-01-18 14:51:25 +0000623 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000624 "OpConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700625
Eric Kunzee5e26762020-10-13 16:11:07 -0700626 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
627 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
628 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100629 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700630
Kevin Cheng9fe17242021-11-10 01:04:39 +0000631 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000632 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000633 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -0700634 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000635 msg = "OpConv2d: " + msg;
636 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700637 return 1;
638 }
639
Eric Kunzee5e26762020-10-13 16:11:07 -0700640 return 0;
641}
642
Tai Lya4d748b2023-03-28 22:06:56 +0000643template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000644int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700645{
646 int in_batch = this->input->getShape()[0];
647 int in_height = this->input->getShape()[1];
648 int in_width = this->input->getShape()[2];
649 int in_channels = this->input->getShape()[3];
650
651 int f_out_channels = this->weight->getShape()[0];
652 int f_height = this->weight->getShape()[1];
653 int f_width = this->weight->getShape()[2];
654 int f_in_channels = this->weight->getShape()[3];
655
656 int b_out_channels = this->bias->getShape()[0];
657
658 int out_batch = this->output->getShape()[0];
659 int out_height = this->output->getShape()[1];
660 int out_width = this->output->getShape()[2];
661 int out_channels = this->output->getShape()[3];
662
Kevin Chengacb550f2021-06-29 15:32:19 -0700663 ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
664 ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
665 in_channels);
666 ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
667 out_channels);
Tai Lya641dd52023-08-11 19:58:50 +0000668 ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d",
669 b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -0700670
TatWai Chong86c403b2022-06-06 20:46:01 -0700671 int pad_top = this->attribute->pad()[0];
672 int pad_bottom = this->attribute->pad()[1];
673 int pad_left = this->attribute->pad()[2];
674 int pad_right = this->attribute->pad()[3];
675
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000676 int stride_y = this->attribute->stride()[0];
677 int stride_x = this->attribute->stride()[1];
678 int dilation_y = this->attribute->dilation()[0];
679 int dilation_x = this->attribute->dilation()[1];
Jerry Gea793f462023-04-11 00:05:02 +0000680
681 // Check Tosa Level
682 auto tosa_level = g_func_config.tosa_level;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000683 LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
684 "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
685 LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
686 "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +0000687 LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
688 LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
689 LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
690 LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
691 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
692 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
Eric Kunzee5e26762020-10-13 16:11:07 -0700693
694 DEBUG_INFO(OP,
695 "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +0000696 "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
Eric Kunzee5e26762020-10-13 16:11:07 -0700697 in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000698 out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
699 pad_left, pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -0700700
701 // GEMM-conv2d, left matrix is input, right matrix is weight
702 Eigen::array<Eigen::Index, 2> im2col_input_dims;
703 im2col_input_dims[0] = out_batch * out_height * out_width;
704 im2col_input_dims[1] = f_height * f_width * f_in_channels;
705
706 Eigen::array<Eigen::Index, 2> im2col_weight_dims;
707 im2col_weight_dims[0] = f_height * f_width * f_in_channels;
708 im2col_weight_dims[1] = f_out_channels;
709
710 Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
711 bias_reshaped_dims[0] = 1;
712 bias_reshaped_dims[1] = b_out_channels;
713
714 Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
715 weight_zp_bcast_dims[0] = f_height;
716 weight_zp_bcast_dims[1] = f_width;
717 weight_zp_bcast_dims[2] = f_in_channels;
718
719 Eigen::array<Eigen::Index, 2> bias_bcast_dims;
720 bias_bcast_dims[0] = out_batch * out_height * out_width;
Tai Lya641dd52023-08-11 19:58:50 +0000721 bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700722
723 Eigen::array<Eigen::Index, 4> col2im_output_dims;
724 col2im_output_dims[0] = out_batch;
725 col2im_output_dims[1] = out_height;
726 col2im_output_dims[2] = out_width;
727 col2im_output_dims[3] = out_channels;
728
729 Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
730
TatWai Chong86c403b2022-06-06 20:46:01 -0700731 Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
732 pad[0] = std::make_pair(0, 0);
733 pad[1] = std::make_pair(pad_top, pad_bottom);
734 pad[2] = std::make_pair(pad_left, pad_right);
735 pad[3] = std::make_pair(0, 0);
Eric Kunzee5e26762020-10-13 16:11:07 -0700736
737 TIn input_val = this->input->getTensor();
738 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +0000739 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700740 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000741 input_val = input_val - (InEigenType)attribute->input_zp();
742 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -0700743 }
744
TatWai Chong86c403b2022-06-06 20:46:01 -0700745 ETensor4<InEigenType> input_padded = input_val.pad(pad);
Eric Kunzee5e26762020-10-13 16:11:07 -0700746
Tai Ly307392a2023-05-12 21:42:19 +0000747 TBias bias_val = this->bias->getTensor();
748
749 if (g_func_config.abs_mode)
750 {
751 // in abs_mode: take abs values of conv operands
752 input_padded = input_padded.abs();
753 weight_val = weight_val.abs();
754 bias_val = bias_val.abs();
755 }
756
Eric Kunzee5e26762020-10-13 16:11:07 -0700757 // extract_image_patches() output [N, KH, KW, H * W, C]
758 // need to transpose to [N, H * W, KH, KW, C]
759 ETensor5<InEigenType> input_extract_patches =
760 input_padded
Jerry Gea793f462023-04-11 00:05:02 +0000761 .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700762 .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
763
764 // reshape input to [N * H * W, KH * KW * C]
765 ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
766
767 // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
768 ETensor2<WeightEigenType> im2col_weight =
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000769 weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
Eric Kunzee5e26762020-10-13 16:11:07 -0700770
771 // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
772 // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
Tai Ly307392a2023-05-12 21:42:19 +0000773 ETensor2<OutEigenType> bias_2d =
774 (bias_val.reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700775
776 // output matrix is [N * H * W, C]
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000777 ETensor2<OutEigenType> contracted_result = (im2col_input.template cast<AccEigenType>().contract(
778 im2col_weight.template cast<AccEigenType>(), contract_dims))
779 .template cast<OutEigenType>();
Eric Kunzee5e26762020-10-13 16:11:07 -0700780
781 // adding bias
James Ward8b390432022-08-12 20:48:56 +0100782 ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
Eric Kunzee5e26762020-10-13 16:11:07 -0700783
784 // reshape back to [N, H, W, C]
785 this->output->getTensor() = biased_output.reshape(col2im_output_dims);
786
Tai Lya4d748b2023-03-28 22:06:56 +0000787 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -0700788 {
James Ward8b390432022-08-12 20:48:56 +0100789 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
790 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -0700791 }
792
793 return GraphNode::eval();
794}
795
Tai Lya4d748b2023-03-28 22:06:56 +0000796template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000797OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Cheng1533b852021-09-01 12:51:58 -0700798 : GraphNode(sgt_, Op_CONV3D, id_)
799{
800 setRequiredOperands(3, 1);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000801 setRequiredRank(5, 5);
Kevin Cheng1533b852021-09-01 12:51:58 -0700802
803 INIT_ATTRIBUTE(Conv);
Kevin Cheng1533b852021-09-01 12:51:58 -0700804}
805
Tai Lya4d748b2023-03-28 22:06:56 +0000806template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000807OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
Kevin Cheng1533b852021-09-01 12:51:58 -0700808{
809 if (attribute)
810 delete attribute;
Kevin Cheng1533b852021-09-01 12:51:58 -0700811}
812
Tai Lya4d748b2023-03-28 22:06:56 +0000813template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +0000814int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Kevin Cheng1533b852021-09-01 12:51:58 -0700815{
816 if (validateRequiredOperands())
817 return 1;
818
819 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
820 {
821 return 1;
822 }
823
824 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
825 if (inputs[2]->getRank() != 1)
826 {
827 printNodeValidationError("OpConv3d: bias tensor must be rank 1");
828 }
829
James Wardd34b3fc2023-01-18 14:51:25 +0000830 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000831 "OpConv3d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -0700832
Kevin Cheng1533b852021-09-01 12:51:58 -0700833 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
834 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
835 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +0100836 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Kevin Cheng1533b852021-09-01 12:51:58 -0700837
Kevin Cheng9fe17242021-11-10 01:04:39 +0000838 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000839 if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000840 weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
Kevin Cheng1533b852021-09-01 12:51:58 -0700841 {
Kevin Cheng9fe17242021-11-10 01:04:39 +0000842 msg = "OpConv3d: " + msg;
843 printNodeValidationError(msg.c_str());
Kevin Cheng1533b852021-09-01 12:51:58 -0700844 return 1;
845 }
846
Kevin Cheng1533b852021-09-01 12:51:58 -0700847 return 0;
848}
849
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Reference implementation of CONV3D: pad the input, seed the output with
    // the (broadcast) bias, then run a direct nested-loop convolution
    // accumulating in AccEigenType.

    // Input layout: [N, D, H, W, IC]
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    // Weight layout: [OC, KD, KH, KW, IC]
    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    // Bias is rank 1: either [OC] or [1] (broadcast).
    int b_out_channels = this->bias->getShape()[0];

    // Output layout: [N, D, H, W, OC]
    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    // b_out_channels == 1 is allowed: the single bias value is broadcast below.
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d",
             b_out_channels, out_channels);

    // Padding order: depth (front/back), then height (top/bottom), then width (left/right).
    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level: enforce the profile's kernel/stride/pad size limits.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL,
                "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    // Spatial-only padding: batch and channel dimensions are never padded.
    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    // Zero points are subtracted only for INT8 operands; other dtypes use
    // the raw tensor values.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // 1. initialize with bias
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    // A bias of shape [1] is additionally broadcast across all output channels.
    bcast[4] = (b_out_channels == 1) ? out_channels : 1;
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value (seeded above)
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            // Source coordinates in the padded input.
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        // Multiply-accumulate in the wider accumulator type.
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    // INT48 accumulators are clamped to the 48-bit quantized range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1010
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    // DEPTHWISE_CONV2D consumes three operands (input, weight, bias) and
    // yields one output.
    setRequiredOperands(3, 1);
    // Input, weight and output tensors must all be rank 4; bias rank is
    // validated separately in checkTensorAttributes().
    setRequiredRank(4, 4);

    // Shares the common Conv attribute (pad/stride/dilation, zero-points).
    INIT_ATTRIBUTE(Conv);
}
1022
Tai Lya4d748b2023-03-28 22:06:56 +00001023template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001024OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001025{
1026 if (attribute)
1027 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001028}
1029
Tai Lya4d748b2023-03-28 22:06:56 +00001030template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001031int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001032{
1033 if (validateRequiredOperands())
1034 return 1;
1035
1036 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1037 {
1038 return 1;
1039 }
1040
1041 // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
1042 if (inputs[2]->getRank() != 1)
1043 {
1044 printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
1045 }
1046
James Wardd34b3fc2023-01-18 14:51:25 +00001047 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001048 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001049
Eric Kunzee5e26762020-10-13 16:11:07 -07001050 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1051 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1052 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001053 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001054
Kevin Cheng9fe17242021-11-10 01:04:39 +00001055 std::string msg;
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001056 if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001057 weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001058 {
Kevin Cheng9fe17242021-11-10 01:04:39 +00001059 msg = "OpDepthwiseConv2d: " + msg;
1060 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001061 return 1;
1062 }
1063
Eric Kunzee5e26762020-10-13 16:11:07 -07001064 return 0;
1065}
1066
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    // Reference implementation of DEPTHWISE_CONV2D: each input channel is
    // convolved with its own set of f_multiplier filters, producing
    // in_channels * f_multiplier output channels.

    // Input layout: [N, H, W, IC]
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    // Weight layout: [KH, KW, IC, M] where M is the channel multiplier.
    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    // Bias is rank 1: either [OC] or [1] (broadcast).
    int b_out_channels = this->bias->getShape()[0];

    // Output layout: [N, H, W, OC] with OC == IC * M.
    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    // b_out_channels == 1 is allowed: the single bias value is broadcast below.
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
             "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level: enforce the profile's kernel/stride/pad size limits.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL,
                "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL,
                "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top, pad_bottom,
               pad_left, pad_right);

    // Spatial-only padding: batch and channel dimensions are never padded.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    // Zero points are subtracted only for INT8 operands.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_padded = input_padded.abs();
        weight_val   = weight_val.abs();
        bias_val     = bias_val.abs();
    }

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    // A bias of shape [1] is additionally broadcast across all output channels.
    bcast[3] = (b_out_channels == 1) ? out_channels : 1;

    // initialize with bias
    this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType.
                                // Output channel ic * f_multiplier + cm interleaves the multiplier
                                // filters per input channel. NOTE(review): the patch index
                                // (ow * out_height + oh) assumes Eigen's patch ordering over the
                                // spatial output — keep in sync with extract_image_patches above.
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh,
                                                                                       ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    // INT48 accumulators are clamped to the 48-bit quantized range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1208
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                                   TosaAttributeBase* attribute_,
                                                                   uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    // FULLY_CONNECTED consumes three operands (input, weight, bias) and
    // yields one output.
    setRequiredOperands(3, 1);
    // Input, weight and output tensors must all be rank 2.
    setRequiredRank(2, 2);

    // FullyConnected attribute carries the input/weight zero-points.
    INIT_ATTRIBUTE(FullyConnected);
}
1220
Tai Lya4d748b2023-03-28 22:06:56 +00001221template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001222OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001223{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001224 if (attribute)
1225 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001226}
1227
Tai Lya4d748b2023-03-28 22:06:56 +00001228template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001229int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001230{
1231 if (validateRequiredOperands())
1232 return 1;
1233
1234 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1235 {
1236 return 1;
1237 }
1238
1239 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1240 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1241 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
1242
1243 if (input->getShape()[1] != weight->getShape()[1])
1244 {
1245 printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
1246 return 1;
1247 }
1248
1249 if (weight->getShape()[0] != bias->getShape()[0])
1250 {
1251 printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
1252 return 1;
1253 }
1254
James Wardd34b3fc2023-01-18 14:51:25 +00001255 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001256 "OpFullyConnected: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001257
James Ward8b390432022-08-12 20:48:56 +01001258 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001259
Tai Lya4d748b2023-03-28 22:06:56 +00001260 ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
1261 "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
1262 ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
1263 "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001264
Eric Kunzee5e26762020-10-13 16:11:07 -07001265 return 0;
1266}
1267
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    // Reference implementation of FULLY_CONNECTED:
    // output[N, OC] = input[N, IC] x weight[OC, IC]^T + bias, accumulated in
    // AccEigenType.

    // Contract input dim 1 (IC) against dim 0 of the transposed weight.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    // Transpose weight from [OC, IC] to [IC, OC] for the contraction.
    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    int b_out_channels = this->bias->getShape()[0];
    int out_channels   = this->output->getShape()[1];

    // b_out_channels == 1 is allowed: the single bias value is broadcast below.
    ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d",
             b_out_channels, out_channels);

    // Reshape bias from [OC] (or [1]) to [1, OC] then broadcast to [N, OC].
    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = b_out_channels;

    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    // Zero points are subtracted only for INT8 operands.
    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    TBias bias_val = this->bias->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of conv operands
        input_val  = input_val.abs();
        weight_val = weight_val.abs();
        bias_val   = bias_val.abs();
    }

    this->output->getTensor() = input_val.template cast<AccEigenType>()
                                    .contract(weight_val.template cast<AccEigenType>(), dims)
                                    .template cast<OutEigenType>() +
                                bias_val.reshape(bias_reshape).broadcast(bias_bcast);

    // INT48 accumulators are clamped to the 48-bit quantized range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}
1320
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    // MATMUL consumes two inputs (a, b) and produces one output.
    setRequiredOperands(2, 1);
    // All operands must be rank 3: a [N, H, C] x b [N, C, W] -> c [N, H, W].
    setRequiredRank(3, 3);

    // Populates this->attribute from attribute_; the destructor deletes it,
    // so this node owns the attribute object (see INIT_ATTRIBUTE macro).
    INIT_ATTRIBUTE(MatMul);
}
1330
Tai Lya4d748b2023-03-28 22:06:56 +00001331template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001332OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001333{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001334 if (attribute)
1335 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001336}
1337
Tai Lya4d748b2023-03-28 22:06:56 +00001338template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001339int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001340{
1341 if (validateRequiredOperands())
1342 return 1;
1343
1344 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1345 {
1346 return 1;
1347 }
1348
James Wardd34b3fc2023-01-18 14:51:25 +00001349 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001350 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001351
Kevin Cheng2d60f002021-06-09 14:18:32 -07001352 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1353 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001354 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001355
Kevin Cheng2d60f002021-06-09 14:18:32 -07001356 ASSERT_MEM(a && b && output);
1357
1358 // a: [N, H, C]
1359 // b: [N, C, W]
1360 // c: [N, H, W]
1361
1362 // Check N
1363 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001364 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001365 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001366 return 1;
1367 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001368 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001369
Kevin Cheng2d60f002021-06-09 14:18:32 -07001370 // Check C
1371 if (a->getShape()[2] != b->getShape()[1])
1372 {
1373 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1374 return 1;
1375 }
1376 C = a->getShape()[2];
1377
1378 // Check H
1379 if (a->getShape()[1] != output->getShape()[1])
1380 {
1381 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1382 return 1;
1383 }
1384 H = a->getShape()[1];
1385
1386 // Check W
1387 if (b->getShape()[2] != output->getShape()[2])
1388 {
1389 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1390 return 1;
1391 }
1392 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001393
Tai Lya4d748b2023-03-28 22:06:56 +00001394 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
1395 "OpMatMul: A zeropoint must be zero for non int8_t data");
1396 ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
1397 "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001398
Eric Kunzee5e26762020-10-13 16:11:07 -07001399 return 0;
1400}
1401
template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    // Batched matrix multiply: for each batch n, contracts a[n] ([H, C]) with
    // b[n] ([C, W]) at accumulator precision, then concatenates the per-batch
    // results into the [N, H, W] output.

    // Contract axis 1 of the first operand with axis 0 of the second.
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    // INT8 operands carry zero points; subtract them before multiplying.
    if (Dtype == TOSA_REF_TYPE_INT8)
    {
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of matmul operands
        a_val = a_val.abs();
        b_val = b_val.abs();
    }

    // Shapes used to flatten one batch slice to rank 2 and back to rank 3.
    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        // Select batch i of each operand.
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        // Multiply at accumulator precision, then narrow to the output type.
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        // First batch initializes the output; later batches append along axis 0.
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    // INT48 accumulators must be saturated to the representable range.
    if (OutDtype == TOSA_REF_TYPE_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
1463
template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    // MAX_POOL2D is unary: one input tensor, one output tensor.
    setRequiredOperands(1, 1);
    // Both operands must be rank 4 ([N, H, W, C] — see eval()).
    setRequiredRank(4, 4);

    // Takes ownership of a Pool attribute copy (kernel/stride/pad).
    INIT_ATTRIBUTE(Pool);
}
1473
Tai Lya4d748b2023-03-28 22:06:56 +00001474template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001475OpMaxPool2d<Dtype>::~OpMaxPool2d()
1476{
1477 if (attribute)
1478 delete attribute;
1479}
1480
Tai Lya4d748b2023-03-28 22:06:56 +00001481template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -07001482int OpMaxPool2d<Dtype>::checkTensorAttributes()
1483{
1484 if (validateRequiredOperands())
1485 return 1;
1486
1487 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
1488 {
1489 return 1;
1490 }
1491
1492 if (inputs[0]->matchType(*outputs[0]))
1493 {
1494 printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
1495 return 1;
1496 }
1497
1498 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1499 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1500
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001501 std::string msg;
Kevin Cheng9fe17242021-11-10 01:04:39 +00001502 if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
Eric Kunzee5e26762020-10-13 16:11:07 -07001503 {
Kevin Cheng7eb93d72021-10-09 01:26:08 +00001504 msg = "OpMaxPool2d: " + msg;
1505 printNodeValidationError(msg.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001506 return 1;
1507 }
1508
1509 return 0;
1510}
1511
template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    // Max pooling via im2col: pad the input with the type's lowest value,
    // extract every kernel_y x kernel_x patch, take the argmax along the
    // patch axis, then gather the maxima and reshape back to [N, H, W, C].
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    // Attribute layout: pad = [top, bottom, left, right], kernel/stride = [y, x].
    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    // im2col matrix: one column per output element, one row per patch position.
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    // Pad only the spatial (H, W) axes.
    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KHxHW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1604
template <TOSA_REF_TYPE Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    // FFT2D takes real+imaginary inputs and produces real+imaginary outputs.
    setRequiredOperands(2, 2);
    // All four operands must be rank 3 ([N, H, W]).
    setRequiredRank(3, 3);

    // Takes ownership of an FFT attribute copy (holds the inverse() flag).
    INIT_ATTRIBUTE(FFT);
}
1614
Tai Lya4d748b2023-03-28 22:06:56 +00001615template <TOSA_REF_TYPE Dtype>
1616OpFFT2d<Dtype>::~OpFFT2d()
1617{
Luke Hutton57287132023-02-06 14:54:18 +00001618 if (attribute)
1619 delete attribute;
1620}
1621
Tai Lya4d748b2023-03-28 22:06:56 +00001622template <TOSA_REF_TYPE Dtype>
Luke Hutton57287132023-02-06 14:54:18 +00001623int OpFFT2d<Dtype>::checkTensorAttributes()
1624{
1625 if (validateRequiredOperands())
1626 return 1;
1627
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001628 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]) ||
1629 validateRequiredRank(outputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001630 {
1631 return 1;
1632 }
1633
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001634 if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || inputs[0]->matchType(*inputs[1]))
Luke Hutton57287132023-02-06 14:54:18 +00001635 {
1636 printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
1637 return 1;
1638 }
1639
1640 in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1641 in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
1642 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1643 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1644
1645 ASSERT_MEM(in_real && in_imag && out_real && out_imag);
1646
1647 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001648 if (check_fft_shape(in_real->getShape(), in_imag->getShape(), out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton57287132023-02-06 14:54:18 +00001649 {
1650 msg = "OpFFT2d: " + msg;
1651 printNodeValidationError(msg.c_str());
1652 return 1;
1653 }
1654
1655 return 0;
1656}
1657
template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
    // Direct (naive) 2-D DFT over each batch: for every output bin (oy, ox)
    // it sums over all input positions, so cost is O((H*W)^2) per batch.
    // The inverse() attribute only flips the sign of the twiddle angle here.
    int in_real_batch  = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width  = this->in_real->getShape()[2];

    int in_imag_batch  = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width  = this->in_imag->getShape()[2];

    int out_real_batch  = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width  = this->out_real->getShape()[2];

    int out_imag_batch  = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width  = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP, "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);

    // sign_val selects forward (+1) vs inverse (-1) transform direction.
    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

    TIn in_real_val = this->in_real->getTensor();
    TIn in_imag_val = this->in_imag->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of real and imag operands
        in_real_val = in_real_val.abs();
        in_imag_val = in_imag_val.abs();
    }

    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                // Accumulate the DFT sum over every input sample.
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = in_real_val(n, iy, ix);
                        OutEigenType val_imag = in_imag_val(n, iy, ix);
                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI *
                            ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
                        // Complex multiply-accumulate of (val_real + i*val_imag) by e^{-ia}.
                        sum_real += val_real * cos(a) + val_imag * sin(a);
                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1732
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    // RFFT2D takes one real input and produces real+imaginary outputs.
    // No INIT_ATTRIBUTE: this op has no attribute payload to copy.
    setRequiredOperands(1, 2);
    // All operands must be rank 3 ([N, H, W]).
    setRequiredRank(3, 3);
}
1740
Tai Lya4d748b2023-03-28 22:06:56 +00001741template <TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001742OpRFFT2d<Dtype>::~OpRFFT2d()
1743{}
Luke Hutton261b7b62023-01-10 14:50:31 +00001744
Tai Lya4d748b2023-03-28 22:06:56 +00001745template <TOSA_REF_TYPE Dtype>
Luke Hutton261b7b62023-01-10 14:50:31 +00001746int OpRFFT2d<Dtype>::checkTensorAttributes()
1747{
1748 if (validateRequiredOperands())
1749 return 1;
1750
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001751 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
Luke Hutton261b7b62023-01-10 14:50:31 +00001752 {
1753 return 1;
1754 }
1755
1756 if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
1757 {
1758 printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
1759 return 1;
1760 }
1761
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001762 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
Luke Hutton261b7b62023-01-10 14:50:31 +00001763 out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
1764 out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
1765
1766 ASSERT_MEM(in && out_real && out_imag);
1767
Luke Hutton57287132023-02-06 14:54:18 +00001768 std::string msg;
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001769 if (check_fft_shape(in->getShape(), {}, out_real->getShape(), out_imag->getShape(), msg))
Luke Hutton261b7b62023-01-10 14:50:31 +00001770 {
Luke Hutton57287132023-02-06 14:54:18 +00001771 msg = "OpRFFT2d: " + msg;
1772 printNodeValidationError(msg.c_str());
Luke Hutton261b7b62023-01-10 14:50:31 +00001773 return 1;
1774 }
1775
1776 return 0;
1777}
1778
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
    // Direct (naive) real-input 2-D DFT per batch. Only out_real_width bins
    // are produced per row (the output tensors' width — the remaining bins of
    // a real-input FFT are redundant by conjugate symmetry).
    int32_t in_batch  = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width  = in->getShape()[2];

    int32_t out_real_batch  = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width  = out_real->getShape()[2];

    int32_t out_imag_batch  = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width  = out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width, out_real_batch, out_real_height, out_real_width, out_imag_batch,
               out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

    TIn in_val = this->in->getTensor();

    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in operand
        in_val = in_val.abs();
    }

    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                // Accumulate the DFT sum over every input sample.
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        sum_real += in_val(n, iy, ix) * cos(a);
                        sum_imag += -in_val(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}
1841
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    // TRANSPOSE_CONV2D consumes input, weight and bias; produces one output.
    setRequiredOperands(3, 1);
    // Input/weight/output are rank 4 (see checkTensorAttributes()).
    setRequiredRank(4, 4);

    // Takes ownership of a TransposeConv attribute copy
    // (out_pad, stride, output_shape, zero points).
    INIT_ATTRIBUTE(TransposeConv);
}
1853
Tai Lya4d748b2023-03-28 22:06:56 +00001854template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001855OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001856{
1857 if (attribute)
1858 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001859}
1860
Tai Lya4d748b2023-03-28 22:06:56 +00001861template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001862int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001863{
1864 if (validateRequiredOperands())
1865 return 1;
1866
1867 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1868 {
1869 return 1;
1870 }
1871
James Wardd34b3fc2023-01-18 14:51:25 +00001872 ERROR_IF(outputs[0]->getDtype() != OutDtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001873 "OpTransposeConv2d: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001874
Eric Kunzee5e26762020-10-13 16:11:07 -07001875 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1876 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1877 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
James Ward8b390432022-08-12 20:48:56 +01001878 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001879
TatWai Chong24594f52022-06-08 00:48:04 -07001880 if (attribute->out_pad().size() != 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001881 {
TatWai Chong24594f52022-06-08 00:48:04 -07001882 printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
Eric Kunzee5e26762020-10-13 16:11:07 -07001883 return 1;
1884 }
1885
1886 if (attribute->stride().size() != 2)
1887 {
1888 printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
1889 return 1;
1890 }
1891
Eric Kunzee5e26762020-10-13 16:11:07 -07001892 if (attribute->output_shape().size() != 4)
1893 {
1894 printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
1895 return 1;
1896 }
1897
Kevin Cheng9fe17242021-11-10 01:04:39 +00001898 for (int32_t i : attribute->stride())
1899 {
1900 if (i < 1)
1901 {
1902 printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
1903 return 1;
1904 }
1905 }
1906
Eric Kunzee5e26762020-10-13 16:11:07 -07001907 for (int d = 0; d < 4; d++)
1908 {
1909 if (attribute->output_shape()[d] != this->output->getShape()[d])
1910 {
1911 printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
1912 return 1;
1913 }
1914 }
1915
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001916 int32_t IH = input->getShape()[1];
1917 int32_t IW = input->getShape()[2];
1918 int32_t OH = output->getShape()[1];
1919 int32_t OW = output->getShape()[2];
1920
1921 int32_t stride_y = attribute->stride()[0];
1922 int32_t stride_x = attribute->stride()[1];
1923 int32_t kernel_h = weight->getShape()[1];
1924 int32_t kernel_w = weight->getShape()[2];
1925
TatWai Chong24594f52022-06-08 00:48:04 -07001926 int32_t out_pad_top = attribute->out_pad()[0];
1927 int32_t out_pad_bottom = attribute->out_pad()[1];
1928 int32_t out_pad_left = attribute->out_pad()[2];
1929 int32_t out_pad_right = attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001930
Eric Kunzec1a97832022-07-01 16:56:09 -07001931 for (size_t i = 0; i < attribute->out_pad().size(); i++)
1932 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001933 ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
1934 "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
Eric Kunzec1a97832022-07-01 16:56:09 -07001935 }
1936
1937 int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
1938 int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001939
1940 if ((OH != H) || (OW != W))
1941 {
1942 std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001943 std::to_string(H) + "," + std::to_string(W) + ")";
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001944 printNodeValidationError(msg.c_str());
1945 return 1;
1946 }
1947
Tai Lya4d748b2023-03-28 22:06:56 +00001948 ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
1949 "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
1950 ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
1951 "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001952
Eric Kunzee5e26762020-10-13 16:11:07 -07001953 return 0;
1954}
1955
Tai Lya4d748b2023-03-28 22:06:56 +00001956template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
James Wardd34b3fc2023-01-18 14:51:25 +00001957int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -07001958{
1959 int in_batch = this->input->getShape()[0];
1960 int in_height = this->input->getShape()[1];
1961 int in_width = this->input->getShape()[2];
1962 int in_channels = this->input->getShape()[3];
1963
1964 int f_out_channels = this->weight->getShape()[0];
1965 int f_height = this->weight->getShape()[1];
1966 int f_width = this->weight->getShape()[2];
1967 int f_in_channels = this->weight->getShape()[3];
1968
1969 int b_out_channels = this->bias->getShape()[0];
1970
1971 int out_batch = this->output->getShape()[0];
1972 int out_height = this->output->getShape()[1];
1973 int out_width = this->output->getShape()[2];
1974 int out_channels = this->output->getShape()[3];
1975
TatWai Chong24594f52022-06-08 00:48:04 -07001976 int out_pad_top = this->attribute->out_pad()[0];
1977 int out_pad_bottom = this->attribute->out_pad()[1];
1978 int out_pad_left = this->attribute->out_pad()[2];
1979 int out_pad_right = this->attribute->out_pad()[3];
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001980
Jerry Gea793f462023-04-11 00:05:02 +00001981 int stride_y = this->attribute->stride()[0];
1982 int stride_x = this->attribute->stride()[1];
Eric Kunzee5e26762020-10-13 16:11:07 -07001983
Kevin Chengacb550f2021-06-29 15:32:19 -07001984 ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
1985 ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
1986 in_channels);
1987 ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
1988 f_out_channels, out_channels);
Tai Lya641dd52023-08-11 19:58:50 +00001989 ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
1990 "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);
Eric Kunzee5e26762020-10-13 16:11:07 -07001991
Jerry Gea793f462023-04-11 00:05:02 +00001992 // Check Tosa Level
1993 auto tosa_level = g_func_config.tosa_level;
1994 LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
1995 LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
1996 LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
Jerry Ge9c9c8da2023-07-19 23:08:16 +00001997 LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL,
1998 "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
Jerry Gea793f462023-04-11 00:05:02 +00001999 LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
2000 LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
2001 LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
2002 LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
2003
Eric Kunzee5e26762020-10-13 16:11:07 -07002004 DEBUG_INFO(OP,
2005 "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
James Wardd34b3fc2023-01-18 14:51:25 +00002006 "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002007 in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
2008 out_height, out_width, out_channels, stride_y, stride_x, out_pad_top, out_pad_bottom, out_pad_left,
2009 out_pad_right);
Eric Kunzee5e26762020-10-13 16:11:07 -07002010
2011 TIn input_val = this->input->getTensor();
2012 TWeight weight_val = this->weight->getTensor();
Tai Lya4d748b2023-03-28 22:06:56 +00002013 if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002014 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002015 input_val = input_val - (InEigenType)attribute->input_zp();
2016 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -07002017 }
2018
Tai Ly307392a2023-05-12 21:42:19 +00002019 TBias bias_val = this->bias->getTensor();
2020
2021 if (g_func_config.abs_mode)
2022 {
2023 // in abs_mode: take abs values of conv operands
2024 input_val = input_val.abs();
2025 weight_val = weight_val.abs();
2026 bias_val = bias_val.abs();
2027 }
2028
Eric Kunzee5e26762020-10-13 16:11:07 -07002029 Eigen::array<Eigen::Index, 4> reshape_dim;
2030 reshape_dim.fill(1);
2031 reshape_dim[3] = b_out_channels;
2032
2033 Eigen::array<Eigen::Index, 4> bcast;
2034 bcast[0] = out_batch;
2035 bcast[1] = out_height;
2036 bcast[2] = out_width;
Tai Lya641dd52023-08-11 19:58:50 +00002037 bcast[3] = (b_out_channels == 1) ? out_channels : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -07002038
2039 // initialize with bias
Tai Ly307392a2023-05-12 21:42:19 +00002040 this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
Eric Kunzee5e26762020-10-13 16:11:07 -07002041
2042 int out_x_origin, out_y_origin;
2043 int out_x, out_y;
2044
2045 // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
2046 for (int ob = 0; ob < out_batch; ob++)
2047 {
2048 for (int ih = 0; ih < in_height; ih++)
2049 {
2050 for (int iw = 0; iw < in_width; iw++)
2051 {
Jerry Gea793f462023-04-11 00:05:02 +00002052 out_x_origin = iw * stride_x + out_pad_left;
2053 out_y_origin = ih * stride_y + out_pad_top;
Eric Kunzee5e26762020-10-13 16:11:07 -07002054 for (int ic = 0; ic < in_channels; ic++)
2055 {
2056 for (int fh = 0; fh < f_height; fh++)
2057 {
2058 for (int fw = 0; fw < f_width; fw++)
2059 {
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002060 out_x = out_x_origin + fw;
2061 out_y = out_y_origin + fh;
Eric Kunzee5e26762020-10-13 16:11:07 -07002062 for (int oc = 0; oc < out_channels; oc++)
2063 {
2064 if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
2065 {
2066 this->output->getTensor()(ob, out_y, out_x, oc) +=
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002067 (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
2068 (AccEigenType)weight_val(oc, fh, fw, ic));
Eric Kunzee5e26762020-10-13 16:11:07 -07002069 }
2070 }
2071 }
2072 }
2073 }
2074 }
2075 }
2076 }
2077
Tai Lya4d748b2023-03-28 22:06:56 +00002078 if (OutDtype == TOSA_REF_TYPE_INT48)
Eric Kunzee5e26762020-10-13 16:11:07 -07002079 {
James Ward8b390432022-08-12 20:48:56 +01002080 this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
2081 this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
Eric Kunzee5e26762020-10-13 16:11:07 -07002082 }
2083
2084 return GraphNode::eval();
2085}
2086
2087// template explicit instantiation
James Ward8b390432022-08-12 20:48:56 +01002088DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
James Ward24dbc422022-10-19 12:20:31 +01002089DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002090DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
Kevin Cheng3a478572021-01-22 17:21:02 -08002091DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -07002092DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
Tai Lya4d748b2023-03-28 22:06:56 +00002093DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002094
James Wardd34b3fc2023-01-18 14:51:25 +00002095DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
2096DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
2097DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
2098DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
2099DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
2100DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
Tai Lya4d748b2023-03-28 22:06:56 +00002101DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002102
Jerry Ge9c9c8da2023-07-19 23:08:16 +00002103// [in_t, weight_t, out_t]
James Wardd34b3fc2023-01-18 14:51:25 +00002104DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
2105DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
2106DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
2107DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
2108DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
2109DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
2110DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002111DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002112
James Wardd34b3fc2023-01-18 14:51:25 +00002113DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
2114DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
2115DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
2116DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
2117DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
2118DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
2119DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002120DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
Kevin Cheng1533b852021-09-01 12:51:58 -07002121
James Wardd34b3fc2023-01-18 14:51:25 +00002122DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
2123DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
2124DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
2125DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
2126DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
2127DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
2128DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002129DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002130
Luke Hutton57287132023-02-06 14:54:18 +00002131DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +00002132DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
Luke Hutton57287132023-02-06 14:54:18 +00002133
James Wardd34b3fc2023-01-18 14:51:25 +00002134DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
2135DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
2136DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
2137DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
2138DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
2139DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
2140DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002141DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002142
James Wardd34b3fc2023-01-18 14:51:25 +00002143DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
2144DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
2145DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
2146DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
2147DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
2148DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +00002149DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002150
James Ward8b390432022-08-12 20:48:56 +01002151DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
James Ward24dbc422022-10-19 12:20:31 +01002152DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002153DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
Kevin Cheng3a478572021-01-22 17:21:02 -08002154DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -07002155DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
Tai Lya4d748b2023-03-28 22:06:56 +00002156DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -07002157
Luke Hutton261b7b62023-01-10 14:50:31 +00002158DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
Tai Lya4d748b2023-03-28 22:06:56 +00002159DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
Luke Hutton261b7b62023-01-10 14:50:31 +00002160
James Wardd34b3fc2023-01-18 14:51:25 +00002161DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
2162DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
2163DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
2164DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
2165DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
2166DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
2167DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
Tai Lya4d748b2023-03-28 22:06:56 +00002168DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);