
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensor_ops.h"
#include "quant_util.h"
#include "template_types.h"
#include "half.hpp"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
                           std::vector<int32_t> input_shape,
                           std::vector<int32_t> output_shape,
                           std::string& msg)
{
    if (attribute->pad().size() != 4)
    {
        msg = "illegal size for attribute padding";
        return 1;
    }

    if (attribute->kernel().size() != 2)
    {
        msg = "illegal size for attribute kernel";
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        msg = "illegal size for attribute stride";
        return 1;
    }

    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->kernel())
    {
        if (i < 1)
        {
            msg = "At least one kernel dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    int32_t IH = input_shape[1];
    int32_t IW = input_shape[2];
    int32_t OH = output_shape[1];
    int32_t OW = output_shape[2];

    int32_t pad_top    = attribute->pad()[0];
    int32_t pad_bottom = attribute->pad()[1];
    int32_t pad_left   = attribute->pad()[2];
    int32_t pad_right  = attribute->pad()[3];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_y = attribute->kernel()[0];
    int32_t kernel_x = attribute->kernel()[1];

    if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
    {
        msg = "At least one pad is >= kernel dimension";
        return 1;
    }

    int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
    int32_t full_W = IW + pad_left + pad_right - kernel_x;

    if ((full_H % stride_y != 0) ||
        (full_W % stride_x != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) ||
        (OW != (full_W / stride_x) + 1))
    {
        msg = "Mismatch between output shape provided and expected output shape (" +
              std::to_string((full_H / stride_y) + 1) + "," +
              std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    return 0;
}

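// Worked example (illustrative, not taken from the spec text): IH = 32, kernel_y = 2,
// stride_y = 2, pad_top = pad_bottom = 0 gives full_H = 32 + 0 + 0 - 2 = 30, which
// divides evenly by the stride, so the only accepted OH is 30 / 2 + 1 = 16.
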
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         DType InDtype,
                         DType WeightDtype,
                         std::string& msg)
{
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH = input_shape[1 + offset_d];
    int32_t IW = input_shape[2 + offset_d];
    int32_t OD = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH = output_shape[1 + offset_d];
    int32_t OW = output_shape[2 + offset_d];

    int32_t stride_d   = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y   = attribute->stride()[0 + offset_d];
    int32_t stride_x   = attribute->stride()[1 + offset_d];
    int32_t kernel_d   = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h   = weights[offset_kernel + offset_d];
    int32_t kernel_w   = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    if ((full_H % stride_y != 0) ||
        (full_W % stride_x != 0) ||
        (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) ||
        (OW != (full_W / stride_x) + 1) ||
        (OD != (full_D / stride_d) + 1))
    {
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" +
              msg_d +
              std::to_string((full_H / stride_y) + 1) + "," +
              std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}

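// Worked example (illustrative): IH = 16, kernel_h = 3, dilation_y = 2, stride_y = 1,
// pad_top = pad_bottom = 0. The dilated kernel spans (3 - 1) * 2 + 1 = 5 rows, so
// full_H = 16 - 1 - (3 - 1) * 2 = 11 and the only accepted OH is 11 / 1 + 1 = 12.
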
int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg) {
    const bool is_rfft = in_imag.empty();
    auto is_power_of_two = [](int32_t n) -> bool
    {
        return (n & (n - 1)) == 0 && n > 0;
    };

    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // RFFT does not have a second input
    if (!is_rfft)
    {
        bool input_check = true;
        for (size_t i = 0; i < in_real.size(); i++)
        {
            if (in_real[i] != in_imag[i])
            {
                input_check = false;
                break;
            }
        }
        if (!input_check)
        {
            msg = "Mismatch between real input shape and imaginary input shape";
            return 1;
        }
    }

    bool output_check = true;
    for (size_t i = 0; i < out_real.size(); i++)
    {
        if (out_real[i] != out_imag[i])
        {
            output_check = false;
            break;
        }
    }
    if (!output_check)
    {
        msg = "Mismatch between real output shape and imaginary output shape";
        return 1;
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

    if (is_rfft)
    {
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    } else {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}

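// Worked example (illustrative): an RFFT2D input of shape [N, 8, 8] has no imaginary
// input, and its outputs keep height 8 but have width 8 / 2 + 1 = 5, since the upper
// half of a real signal's spectrum is redundant by conjugate symmetry.
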
template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 4);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]))
    {
        return 1;
    }

    int32_t output_rank = inputs[0]->getRank() - 1;
    if (output_rank != outputs[0]->getRank())
    {
        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
        return 1;
    }

    if (outputs[0]->getDtype() != DType_INT32)
    {
        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
    {
        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input))");
        return 1;
    }

    bool shape_check = true;
    for (int32_t i = 0; i < input->getRank(); i++)
    {
        if (i < attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i])
            {
                shape_check = false;
                break;
            }
        }
        else if (i > attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i - 1])
            {
                shape_check = false;
                break;
            }
        }
        // No need to check i == axis
    }
    if (!shape_check)
    {
        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
        return 1;
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
    Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());

    this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });

    return GraphNode::eval();
}

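// Shape example (illustrative): ARGMAX over axis 1 of an input shaped [2, 3, 4]
// removes that axis, giving an INT32 output shaped [2, 4].
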
template <DType Dtype, DType AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype, DType AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype, DType AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");

    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

// This calculates the number of padding elements used for each location along an axis.
// Average pooling only divides by the number of elements actually used, excluding padding.
// This function is written in terms of left/right, but is also used for vertical padding
// with top/bottom.
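// Worked example (illustrative): in_size = 4, kernel_size = 3, stride = 1,
// pad_left = pad_right = 1 gives out_size = 4 and a div map of [2, 3, 3, 2]:
// the first and last windows each overlap one padding element, so those outputs
// divide by 2 instead of 3.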
template <DType Dtype, DType AccDtype>
ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
    ETensor1<int32_t> result(out_size);

    result.setConstant(kernel_size);

    // adjust divisors on the left side for padding
    // We start at the leftmost output element, and remove pad_left - (index * stride) elements
    // until we have no more padding being used
    for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++) {
        int32_t adjust = pad_left - (index * stride);
        result(index) -= adjust;
    }

    // The process repeats on the right side. Padding starts taking effect as we
    // near the rightmost input element. The first output element which touches
    // padding is defined in the initialization of index below. Then we keep moving
    // to the right, increasing padding until we get to the last output element.
    int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
    for (; index < out_size; index++) {
        int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
        result(index) -= adjust;
    }
    return result;
}

// assuming input and output tensors have the same scale, as in the TFLite reference,
// so there is no need to scale input and output
template <DType Dtype, DType AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];
    int kernel_h   = this->attribute->kernel()[0];
    int kernel_w   = this->attribute->kernel()[1];
    int stride_h   = this->attribute->stride()[0];
    int stride_w   = this->attribute->stride()[1];

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
               kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_h * kernel_w;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (Dtype == DType_INT8)
    {
        input_val = input_val - (InEigenType)attribute->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // assuming input and output have same scales
    // so input and output scaling is not required
    // TODO: check whether TOSA makes this assumption
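    // In the integer branch below, the divide is done in fixed point:
    // reciprocal_scale(div, multiplier, shift) picks a multiplier/shift pair so that
    // apply_scale_32(value, multiplier, shift) approximates value / div (descriptive
    // note; see quant_util.h for the exact rounding behavior).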

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_h * kernel_w; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide with div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h, pad_top, pad_bottom);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w, pad_left, pad_right);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    ETensor2<int32_t> dm2_w = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
    ETensor2<int32_t> dm2_h = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
    ETensor4<int32_t> div_map =
        dm2_h.contract(dm2_w, contract_dims)
            .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
            .broadcast(bcast);
    if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
    {
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        // Case for float-types
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                                   TosaAttributeBase* attribute_,
                                                   uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' checked separately since it doesn't make sense to make required rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv2d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpConv2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_h   = this->attribute->stride()[0];
    int stride_w   = this->attribute->stride()[1];
    int dilation_h = this->attribute->dilation()[0];
    int dilation_w = this->attribute->dilation()[1];

    DEBUG_INFO(OP,
               "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
               "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
               pad_bottom, pad_left, pad_right);

    // GEMM-conv2d, left matrix is input, right matrix is weight
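    // Shape sketch (illustrative): with N = 1, OH = OW = 4, KH = KW = 3, IC = 8 and
    // OC = 16, the im2col input below becomes a [1*4*4, 3*3*8] = [16, 72] matrix and
    // the weights a [72, 16] matrix, so one contraction computes every output pixel.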
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = out_batch * out_height * out_width;
    im2col_input_dims[1] = f_height * f_width * f_in_channels;

    Eigen::array<Eigen::Index, 2> im2col_weight_dims;
    im2col_weight_dims[0] = f_height * f_width * f_in_channels;
    im2col_weight_dims[1] = f_out_channels;

    Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
    bias_reshaped_dims[0] = 1;
    bias_reshaped_dims[1] = b_out_channels;

    Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
    weight_zp_bcast_dims[0] = f_height;
    weight_zp_bcast_dims[1] = f_width;
    weight_zp_bcast_dims[2] = f_in_channels;

    Eigen::array<Eigen::Index, 2> bias_bcast_dims;
    bias_bcast_dims[0] = out_batch * out_height * out_width;
    bias_bcast_dims[1] = 1;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // need to transpose to [N, H * W, KH, KW, C]
    ETensor5<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });

    // reshape input to [N * H * W, KH * KW * C]
    ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);

    // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
    ETensor2<WeightEigenType> im2col_weight =
        weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);

    // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scales it
    // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
    ETensor2<OutEigenType> bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();

    // output matrix is [N * H * W, C]
    ETensor2<OutEigenType> contracted_result =
        (im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims)).template cast<OutEigenType>();

    // adding bias
    ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;

    // reshape back to [N, H, W, C]
    this->output->getTensor() = biased_output.reshape(col2im_output_dims);

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                                   TosaAttributeBase* attribute_,
                                                   uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(5);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' checked separately since it doesn't make sense to make required rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpConv3d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpConv3d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_h = this->attribute->stride()[1];
    int stride_w = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_h = this->attribute->dilation()[1];
    int dilation_w = this->attribute->dilation()[2];

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
        dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    // 1. initialize with bias
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    bcast[4] = 1;
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
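    // Index sketch (illustrative): each output element reads the padded input at
    // d_idx = od * stride_d + fd * dilation_d (and likewise for h/w), so with
    // stride_d = 2 and dilation_d = 1, output depth od = 3 and filter tap fd = 1
    // read padded-input depth 2 * 3 + 1 = 7.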
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_h + fh * dilation_h;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_w + fw * dilation_w;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' checked separately since it doesn't make sense to make required rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpDepthwiseConv2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_h   = this->attribute->stride()[0];
    int stride_w   = this->attribute->stride()[1];
    int dilation_h = this->attribute->dilation()[0];
    int dilation_w = this->attribute->dilation()[1];

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
               pad_bottom, pad_left, pad_right);

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution
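    // Rationale (descriptive): unlike CONV2D, each input channel contributes only to
    // its own f_multiplier output channels, e.g. in_channels = 8 with f_multiplier = 2
    // gives out_channels = 16, and output channel ic * f_multiplier + cm is computed
    // solely from input channel ic.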

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID);

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                                   TosaAttributeBase* attribute_,
                                                                   uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(2);

    INIT_ATTRIBUTE(FullyConnected);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    if (weight->getShape()[0] != bias->getShape()[0])
    {
        printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
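    // weight is stored as [OC, IC]; applying weight_shuffle below transposes it to
    // [IC, OC] so that contracting the [N, IC] input over DimPair(1, 0) yields the
    // [N, OC] output (descriptive note; the shapes follow checkTensorAttributes()).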

    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = this->bias->getShape()[0];

    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    this->output->getTensor() =
        input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
        this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}

James Wardd34b3fc2023-01-18 14:51:25 +00001262template <DType Dtype, DType OutDtype>
1263OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
Kevin Chengacb550f2021-06-29 15:32:19 -07001264 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -07001265 uint64_t id_)
1266 : GraphNode(sgt_, Op_MATMUL, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001267{
1268 setRequiredOperands(2, 1);
Kevin Cheng2d60f002021-06-09 14:18:32 -07001269 setRequiredRank(3);
Eric Kunzee5e26762020-10-13 16:11:07 -07001270
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001271 INIT_ATTRIBUTE(MatMul);
Eric Kunzee5e26762020-10-13 16:11:07 -07001272}
1273
James Wardd34b3fc2023-01-18 14:51:25 +00001274template <DType Dtype, DType OutDtype>
1275OpMatMul<Dtype, OutDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001276{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001277 if (attribute)
1278 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001279}
1280
James Wardd34b3fc2023-01-18 14:51:25 +00001281template <DType Dtype, DType OutDtype>
1282int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001283{
1284 if (validateRequiredOperands())
1285 return 1;
1286
1287 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1288 {
1289 return 1;
1290 }
1291
James Wardd34b3fc2023-01-18 14:51:25 +00001292 ERROR_IF(outputs[0]->getDtype() != OutDtype,
James Ward8b390432022-08-12 20:48:56 +01001293 "OpMatMul: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001294
Kevin Cheng2d60f002021-06-09 14:18:32 -07001295 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1296 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
James Ward8b390432022-08-12 20:48:56 +01001297 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001298
Kevin Cheng2d60f002021-06-09 14:18:32 -07001299 ASSERT_MEM(a && b && output);
1300
1301 // a: [N, H, C]
1302 // b: [N, C, W]
1303 // c: [N, H, W]
1304
1305 // Check N
1306 if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001307 {
Kevin Cheng2d60f002021-06-09 14:18:32 -07001308 printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
Eric Kunzee5e26762020-10-13 16:11:07 -07001309 return 1;
1310 }
Kevin Cheng2d60f002021-06-09 14:18:32 -07001311 N = a->getShape()[0];
Eric Kunzee5e26762020-10-13 16:11:07 -07001312
Kevin Cheng2d60f002021-06-09 14:18:32 -07001313 // Check C
1314 if (a->getShape()[2] != b->getShape()[1])
1315 {
1316 printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
1317 return 1;
1318 }
1319 C = a->getShape()[2];
1320
1321 // Check H
1322 if (a->getShape()[1] != output->getShape()[1])
1323 {
1324 printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
1325 return 1;
1326 }
1327 H = a->getShape()[1];
1328
1329 // Check W
1330 if (b->getShape()[2] != output->getShape()[2])
1331 {
1332 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1333 return 1;
1334 }
1335 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001336
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001337 ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
1338 ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001339
Eric Kunzee5e26762020-10-13 16:11:07 -07001340 return 0;
1341}
1342
James Wardd34b3fc2023-01-18 14:51:25 +00001343template <DType Dtype, DType OutDtype>
1344int OpMatMul<Dtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -07001345{
1346 typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
1347 Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
1348
1349 TIn a_val = this->a->getTensor();
1350 TIn b_val = this->b->getTensor();
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001351 if (Dtype == DType_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -07001352 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001353 a_val = a_val - (InEigenType)attribute->a_zp();
1354 b_val = b_val - (InEigenType)attribute->b_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -07001355 }
1356
Kevin Cheng2d60f002021-06-09 14:18:32 -07001357 Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
1358 Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
1359 Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });
1360
1361 Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
1362 Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });
1363
1364 Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
1365 Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });
1366
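    // Worked example of the per-batch slicing below (illustrative values):
    // with a = [2, 3, 4] and b = [2, 4, 5], batch i is sliced to [1, 3, 4] and
    // [1, 4, 5], reshaped to rank-2 [3, 4] and [4, 5], contracted to [3, 5],
    // then reshaped to [1, 3, 5] and concatenated along the batch axis.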
    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpMaxPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch = this->in->getShape()[0];
    int in_height = this->in->getShape()[1];
    int in_width = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch = this->out->getShape()[0];
    int out_height = this->out->getShape()[1];
    int out_width = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left = this->attribute->pad()[2];
    int pad_right = this->attribute->pad()[3];

    int kernel_h = this->attribute->kernel()[0];
    int kernel_w = this->attribute->kernel()[1];
    int stride_h = this->attribute->stride()[0];
    int stride_w = this->attribute->stride()[1];

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
               kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_h * kernel_w;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

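    // Illustrative im2col sizing (example values only): a 2x2 kernel pooled
    // over N * OH * OW * C = 32 output positions gives an im2col matrix of
    // shape [4, 32], one KH * KW patch per column.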
    // Get the maximum of the KH x KW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // Indexing the patches with the argmax array gives the pooled result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}

template <DType Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
                        TosaAttributeBase* attribute_,
                        uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    setRequiredOperands(2, 2);
    setRequiredRank(3);

    INIT_ATTRIBUTE(FFT);
}

template <DType Dtype>
OpFFT2d<Dtype>::~OpFFT2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) ||
        validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) ||
        inputs[0]->matchType(*inputs[1]))
    {
        printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in_real && in_imag && out_real && out_imag);

    std::string msg;
    if (check_fft_shape(in_real->getShape(), in_imag->getShape(),
                        out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpFFT2d<Dtype>::eval()
{
    int in_real_batch = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width = this->in_real->getShape()[2];

    int in_imag_batch = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width = this->in_imag->getShape()[2];

    int out_real_batch = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width = this->out_real->getShape()[2];

    int out_imag_batch = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width = this->out_imag->getShape()[2];

    DEBUG_INFO(OP,
               "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width,
               in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width,
               out_imag_batch, out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    if (attribute->inverse())
    {
        sign_val = -1.0;
    }
1618
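    // The loops below evaluate the 2D DFT directly:
    //   out[oy, ox] = sum_{iy, ix} in[iy, ix] * exp(-j * 2 * pi * (iy * oy / H + ix * ox / W))
    // with the sign of the exponent flipped when inverse() is set. Note that
    // no 1 / (H * W) normalization is applied on the inverse path.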
    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = this->in_real->getTensor()(n, iy, ix);
                        OutEigenType val_imag = this->in_imag->getTensor()(n, iy, ix);
                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = sign_val * 2 * M_PI * ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
                        sum_real += val_real * cos(a) + val_imag * sin(a);
                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}

template <DType Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
                          TosaAttributeBase* attribute_,
                          uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    setRequiredOperands(1, 2);
    setRequiredRank(3);
}

template <DType Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d() {}

template <DType Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) ||
        validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
    {
        printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in && out_real && out_imag);

    std::string msg;
    if (check_fft_shape(in->getShape(), {},
                        out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpRFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpRFFT2d<Dtype>::eval()
{
    int32_t in_batch = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width = in->getShape()[2];

    int32_t out_real_batch = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width = out_real->getShape()[2];

    int32_t out_imag_batch = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width = out_imag->getShape()[2];

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width,
               out_real_batch, out_real_height, out_real_width,
               out_imag_batch, out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

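    // The input is real-valued, so only the non-redundant half of the spectrum
    // is produced: out_real_width is expected to be in_width / 2 + 1, with the
    // remaining bins recoverable by conjugate symmetry.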
    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        sum_real += this->in->getTensor()(n, iy, ix) * cos(a);
                        sum_imag += -this->in->getTensor()(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(TransposeConv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpTransposeConv2d: Output data type not supported for this configuration of operator");

    input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->out_pad().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->output_shape().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
        return 1;
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
            return 1;
        }
    }

    for (int d = 0; d < 4; d++)
    {
        if (attribute->output_shape()[d] != this->output->getShape()[d])
        {
            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
            return 1;
        }
    }

    int32_t IH = input->getShape()[1];
    int32_t IW = input->getShape()[2];
    int32_t OH = output->getShape()[1];
    int32_t OW = output->getShape()[2];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_h = weight->getShape()[1];
    int32_t kernel_w = weight->getShape()[2];

    int32_t out_pad_top = attribute->out_pad()[0];
    int32_t out_pad_bottom = attribute->out_pad()[1];
    int32_t out_pad_left = attribute->out_pad()[2];
    int32_t out_pad_right = attribute->out_pad()[3];

    for (size_t i = 0; i < attribute->out_pad().size(); i++)
    {
        ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]), "OpTransposeConv2d: At least one out_pad value is smaller than or equal to the negated kernel dimension");
    }

    int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
    int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
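    // Worked example (illustrative values): IH = 4, stride_y = 2, kernel_h = 3
    // and zero out_pad give H = (4 - 1) * 2 + 0 + 0 + 3 = 9, which must match
    // the provided output height.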

    if ((OH != H) || (OW != W))
    {
        std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
                          std::to_string(H) + "," +
                          std::to_string(W) + ")";
        printNodeValidationError(msg.c_str());
        return 1;
    }

    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch = this->input->getShape()[0];
    int in_height = this->input->getShape()[1];
    int in_width = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height = this->weight->getShape()[1];
    int f_width = this->weight->getShape()[2];
    int f_in_channels = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch = this->output->getShape()[0];
    int out_height = this->output->getShape()[1];
    int out_width = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    int out_pad_top = this->attribute->out_pad()[0];
    int out_pad_bottom = this->attribute->out_pad()[1];
    int out_pad_left = this->attribute->out_pad()[2];
    int out_pad_right = this->attribute->out_pad()[3];

    int stride_h = this->attribute->stride()[0];
    int stride_w = this->attribute->stride()[1];

    ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
             f_out_channels, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    DEBUG_INFO(OP,
               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
               out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
               out_pad_bottom, out_pad_left, out_pad_right);

    TIn input_val = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    int out_x_origin, out_y_origin;
    int out_x, out_y;

    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
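    // The scatter below is the adjoint of a strided convolution: each input
    // pixel (ih, iw) adds input * weight into a kernel-sized window of the
    // output anchored at (ih * stride_h + out_pad_top, iw * stride_w + out_pad_left).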
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int ih = 0; ih < in_height; ih++)
        {
            for (int iw = 0; iw < in_width; iw++)
            {
                out_x_origin = iw * stride_w + out_pad_left;
                out_y_origin = ih * stride_h + out_pad_top;
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int fh = 0; fh < f_height; fh++)
                    {
                        for (int fw = 0; fw < f_width; fw++)
                        {
                            out_x = out_x_origin + fw;
                            out_y = out_y_origin + fh;
                            for (int oc = 0; oc < out_channels; oc++)
                            {
                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                                {
                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
                                        (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
                                                       (AccEigenType)weight_val(oc, fh, fw, ic));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);

DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);

DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);

DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);

DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);

DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);

DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);

DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);

DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);

DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);