// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensor_ops.h"
#include "quant_util.h"
#include "template_types.h"
#include "half.hpp"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
                           std::vector<int32_t> input_shape,
                           std::vector<int32_t> output_shape,
                           std::string& msg)
{
    if (attribute->pad().size() != 4)
    {
        msg = "illegal size for attribute padding";
        return 1;
    }

    if (attribute->kernel().size() != 2)
    {
        msg = "illegal size for attribute kernel";
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        msg = "illegal size for attribute stride";
        return 1;
    }

    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->kernel())
    {
        if (i < 1)
        {
            msg = "At least one kernel dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    int32_t IH = input_shape[1];
    int32_t IW = input_shape[2];
    int32_t OH = output_shape[1];
    int32_t OW = output_shape[2];

    int32_t pad_top    = attribute->pad()[0];
    int32_t pad_bottom = attribute->pad()[1];
    int32_t pad_left   = attribute->pad()[2];
    int32_t pad_right  = attribute->pad()[3];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_y = attribute->kernel()[0];
    int32_t kernel_x = attribute->kernel()[1];

    if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
    {
        msg = "At least one pad is >= kernel dimension";
        return 1;
    }

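    // The pooled output size must be exact: full_H and full_W below must divide evenly
    // by the stride. Worked example (values assumed for illustration): IH = 32,
    // pad_top = pad_bottom = 1, kernel_y = 2, stride_y = 2 gives
    // full_H = 32 + 1 + 1 - 2 = 32; 32 % 2 == 0, so the expected OH is (32 / 2) + 1 = 17.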
    int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
    int32_t full_W = IW + pad_left + pad_right - kernel_x;

    if ((full_H % stride_y != 0) ||
        (full_W % stride_x != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) ||
        (OW != (full_W / stride_x) + 1))
    {
        msg = "Mismatch between output shape provided and expected output shape (" +
              std::to_string((full_H / stride_y) + 1) + "," +
              std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    return 0;
}

int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         DType InDtype,
                         DType WeightDtype,
                         std::string& msg)
{
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH = input_shape[1 + offset_d];
    int32_t IW = input_shape[2 + offset_d];
    int32_t OD = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH = output_shape[1 + offset_d];
    int32_t OW = output_shape[2 + offset_d];

    int32_t stride_d   = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y   = attribute->stride()[0 + offset_d];
    int32_t stride_x   = attribute->stride()[1 + offset_d];
    int32_t kernel_d   = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h   = weights[offset_kernel + offset_d];
    int32_t kernel_w   = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

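    // Output size follows the dilated-convolution formula; the effective kernel extent
    // is (kernel - 1) * dilation + 1. Worked example (values assumed for illustration):
    // IH = 16, pad_top = pad_bottom = 1, kernel_h = 3, dilation_y = 2 gives
    // full_H = 16 - 1 + 1 + 1 - (3 - 1) * 2 = 13; with stride_y = 1 the expected OH is 14.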
    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    if ((full_H % stride_y != 0) ||
        (full_W % stride_x != 0) ||
        (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) ||
        (OW != (full_W / stride_x) + 1) ||
        (OD != (full_D / stride_d) + 1))
    {
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" +
              msg_d +
              std::to_string((full_H / stride_y) + 1) + "," +
              std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 4);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]))
    {
        return 1;
    }

    int32_t output_rank = inputs[0]->getRank() - 1;
    if (output_rank != outputs[0]->getRank())
    {
        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
        return 1;
    }

    if (outputs[0]->getDtype() != DType_INT32)
    {
        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
    {
        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input) - 1]");
        return 1;
    }

    bool shape_check = true;
    for (int32_t i = 0; i < input->getRank(); i++)
    {
        if (i < attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i])
            {
                shape_check = false;
                break;
            }
        }
        else if (i > attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i - 1])
            {
                shape_check = false;
                break;
            }
        }
        // No need to check i == axis
    }
    if (!shape_check)
    {
        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
        return 1;
    }

    return 0;
}

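// eval() relies on Eigen's argmax() reduction, which yields a DenseIndex (a signed,
// pointer-sized integer) per output element; the unaryExpr below narrows each index to
// the INT32 output type.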
template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
    Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());

    this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });

    return GraphNode::eval();
}

template <DType Dtype, DType AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype, DType AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype, DType AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");

    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

// This calculates the divisor for each output location along an axis: the number of
// input elements covered by the window, excluding padding.
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
template <DType Dtype, DType AccDtype>
ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
    ETensor1<int32_t> result(out_size);

    result.setConstant(kernel_size);

    // adjust divisors on the left side for padding
    // We start at the leftmost output element, and remove pad_left - (index * stride) elements
    // until we have no more padding being used
    for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++) {
        int32_t adjust = pad_left - (index * stride);
        result(index) -= adjust;
    }

    // The process repeats on the right side. Padding starts taking effect as we
    // near the rightmost input element. The first output element which touches
    // padding is defined in the initialization of index below. Then we keep moving
    // to the right, increasing padding until we get to the last output element.
    int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
    for (; index < out_size; index++) {
        int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
        result(index) -= adjust;
    }
    return result;
}
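// Worked example (values assumed for illustration): in_size = 4, out_size = 4,
// kernel_size = 3, stride = 1, pad_left = pad_right = 1 produces div_map = [2, 3, 3, 2];
// the first and last windows each cover one padding element, so they divide by 2
// instead of 3.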

// Assuming input and output tensors have the same scale, as in the TFLite reference,
// no input or output rescaling is needed.
template <DType Dtype, DType AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];
    int kernel_h   = this->attribute->kernel()[0];
    int kernel_w   = this->attribute->kernel()[1];
    int stride_h   = this->attribute->stride()[0];
    int stride_w   = this->attribute->stride()[1];

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
               kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_h * kernel_w;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (Dtype == DType_INT8)
    {
        input_val = input_val - (InEigenType)attribute->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // assuming input and output have same scales
    // so input and output scaling is not required
    // TODO: check whether TOSA actually makes this assumption

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_h * kernel_w; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide with div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h, pad_top, pad_bottom);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w, pad_left, pad_right);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    ETensor4<int32_t> div_map =
        div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
            .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
            .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
            .broadcast(bcast);
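    // For integer types, the division by the element count is done in fixed point:
    // reciprocal_scale() converts each divisor into a (multiplier, shift) pair and
    // apply_scale_32() applies it, following the TOSA integer scaling rules.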
    if (Dtype != DType_FLOAT && Dtype != DType_FP16)
    {
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        // Case for float-type pooling
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                                   TosaAttributeBase* attribute_,
                                                   uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv2d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpConv2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_h   = this->attribute->stride()[0];
    int stride_w   = this->attribute->stride()[1];
    int dilation_h = this->attribute->dilation()[0];
    int dilation_w = this->attribute->dilation()[1];

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    DEBUG_INFO(OP,
               "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
               "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
               pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    // GEMM-conv2d, left matrix is input, right matrix is weight
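    // The contraction below computes [N * OH * OW, KH * KW * IC] x [KH * KW * IC, OC]
    // = [N * OH * OW, OC], which is reshaped back to [N, OH, OW, OC] at the end.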
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = out_batch * out_height * out_width;
    im2col_input_dims[1] = f_height * f_width * f_in_channels;

    Eigen::array<Eigen::Index, 2> im2col_weight_dims;
    im2col_weight_dims[0] = f_height * f_width * f_in_channels;
    im2col_weight_dims[1] = f_out_channels;

    Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
    bias_reshaped_dims[0] = 1;
    bias_reshaped_dims[1] = b_out_channels;

    Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
    weight_zp_bcast_dims[0] = f_height;
    weight_zp_bcast_dims[1] = f_width;
    weight_zp_bcast_dims[2] = f_in_channels;

    Eigen::array<Eigen::Index, 2> bias_bcast_dims;
    bias_bcast_dims[0] = out_batch * out_height * out_width;
    bias_bcast_dims[1] = 1;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // need to transpose to [N, H * W, KH, KW, C]
    ETensor5<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });

    // reshape input to [N * H * W, KH * KW * C]
    ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);

    // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
    ETensor2<WeightEigenType> im2col_weight =
        weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);

    // bias_multiplier (* bias_scale and >> bias_shift) is not applied here since TFLite
    // already scales the bias; reshape from [C] to [1, C] and broadcast to [N * H * W, C]
    ETensor2<OutEigenType> bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();

    // output matrix is [N * H * W, C]
    ETensor2<OutEigenType> contracted_result =
        (im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims)).template cast<OutEigenType>();

    // adding bias
    ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;

    // reshape back to [N, H, W, C]
    this->output->getTensor() = biased_output.reshape(col2im_output_dims);

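    // INT48 accumulators are held in a wider integer type in the reference model, so the
    // result is clamped back into the 48-bit range here.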
    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                                   TosaAttributeBase* attribute_,
                                                   uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(5);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpConv3d<InDtype, WeightDtype, AccDtype>::~OpConv3d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpConv3d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpConv3d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_h = this->attribute->stride()[1];
    int stride_w = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_h = this->attribute->dilation()[1];
    int dilation_w = this->attribute->dilation()[2];

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
        dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    // 1. initialize with bias
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    bcast[4] = 1;
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
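    // There is no im2col/GEMM path for conv3d; each output element accumulates
    // f_depth * f_height * f_width * in_channels products directly in AccEigenType.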
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_h + fh * dilation_h;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_w + fw * dilation_w;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::~OpDepthwiseConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpDepthwiseConv2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_h   = this->attribute->stride()[0];
    int stride_w   = this->attribute->stride()[1];
    int dilation_h = this->attribute->dilation()[0];
    int dilation_w = this->attribute->dilation()[1];

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
               pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID);

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
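    // Each input channel ic contributes f_multiplier output channels; the output channel
    // index is ic * f_multiplier + cm, matching the [KH, KW, IC, M] weight layout.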
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                                   TosaAttributeBase* attribute_,
                                                                   uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(2);

    INIT_ATTRIBUTE(FullyConnected);
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    if (weight->getShape()[0] != bias->getShape()[0])
    {
        printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");

    return 0;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
{
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = this->bias->getShape()[0];

    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

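    // The weight tensor is stored as [OC, IC]; the shuffle above transposes it to
    // [IC, OC], so the contraction computes [N, IC] x [IC, OC] = [N, OC] before the
    // broadcast bias is added.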
    this->output->getTensor() =
        input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
        this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}

James Ward8b390432022-08-12 20:48:56 +01001187template <DType Dtype, DType AccDtype>
1188OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
Kevin Chengacb550f2021-06-29 15:32:19 -07001189 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -07001190 uint64_t id_)
1191 : GraphNode(sgt_, Op_MATMUL, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001192{
1193 setRequiredOperands(2, 1);
Kevin Cheng2d60f002021-06-09 14:18:32 -07001194 setRequiredRank(3);
Eric Kunzee5e26762020-10-13 16:11:07 -07001195
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001196 INIT_ATTRIBUTE(MatMul);
Eric Kunzee5e26762020-10-13 16:11:07 -07001197}
1198
James Ward8b390432022-08-12 20:48:56 +01001199template <DType Dtype, DType AccDtype>
1200OpMatMul<Dtype, AccDtype>::~OpMatMul()
Eric Kunzee5e26762020-10-13 16:11:07 -07001201{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001202 if (attribute)
1203 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001204}

template <DType Dtype, DType AccDtype>
int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpMatMul: Output data type not supported for this configuration of operator");

    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(a && b && output);

    // a: [N, H, C]
    // b: [N, C, W]
    // c: [N, H, W]

    // Check N
    if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
    {
        printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
        return 1;
    }
    N = a->getShape()[0];

    // Check C
    if (a->getShape()[2] != b->getShape()[1])
    {
        printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
        return 1;
    }
    C = a->getShape()[2];

    // Check H
    if (a->getShape()[1] != output->getShape()[1])
    {
        printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
        return 1;
    }
    H = a->getShape()[1];

    // Check W
    if (b->getShape()[2] != output->getShape()[2])
    {
        printNodeValidationError("OpMatMul operator b.shape[2] should match output.shape[2]");
        return 1;
    }
    W = b->getShape()[2];

    ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");

    return 0;
}
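
// Worked shape example (illustrative numbers, not part of the checks above):
// for a = [2, 3, 4] and b = [2, 4, 5], the code binds N = 2, H = 3, C = 4 and
// W = 5, so the only output shape that validates is [2, 3, 5].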

template <DType Dtype, DType AccDtype>
int OpMatMul<Dtype, AccDtype>::eval()
{
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    if (Dtype == DType_INT8)
    {
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
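
// A scalar sketch of the batched matrix multiply above (illustrative only;
// the `sketch` helper is an assumption for exposition, not part of the
// reference model). Shapes follow the comments in checkTensorAttributes():
// a is [N, H, C], b is [N, C, W] and c is [N, H, W], with zero points
// subtracted before the widening multiply as in the INT8 path of eval().
namespace sketch {
inline void matmul_i8_ref(const int8_t* a, const int8_t* b, int32_t* c,
                          int N, int H, int C, int W,
                          int32_t a_zp, int32_t b_zp)
{
    for (int n = 0; n < N; n++)
        for (int h = 0; h < H; h++)
            for (int w = 0; w < W; w++)
            {
                int32_t acc = 0;
                for (int k = 0; k < C; k++)
                {
                    int32_t av = (int32_t)a[(n * H + h) * C + k] - a_zp;
                    int32_t bv = (int32_t)b[(n * C + k) * W + w] - b_zp;
                    acc += av * bv;
                }
                c[(n * H + h) * W + w] = acc;
            }
}
} // namespace sketch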

template <DType Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpMaxPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int kernel_h = this->attribute->kernel()[0];
    int kernel_w = this->attribute->kernel()[1];
    int stride_h = this->attribute->stride()[0];
    int stride_w = this->attribute->stride()[1];

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
               kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_h * kernel_w;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() outputs [N, KH, KW, H * W, C],
    // which is transposed to [KH, KW, N, H * W, C]
    // and reshaped to [KH * KW, N * H * W * C].
    //
    // Set the padding value to the most negative value representable by the
    // datatype so that any padding value is equal to or smaller than the
    // actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KH x KW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // Indexing the patches with the argmax array yields the pooled values
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
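
// A scalar sketch of one pooling window (illustrative only; the helper is an
// assumption for exposition, not part of the reference model, and assumes
// <algorithm> and <limits> are reachable via the existing includes). It
// mirrors the im2col/argmax trick above: because the pad value is the type's
// lowest() representable value, taking the plain maximum over the padded
// window gives the same answer as maximizing over only the valid elements.
namespace sketch {
template <typename T>
T max_pool_window_ref(const T* padded,  // one already-padded HWC batch slice
                      int padded_width, int channels,
                      int y0, int x0,   // top-left corner of the window
                      int kernel_h, int kernel_w, int c)
{
    T m = std::numeric_limits<T>::lowest();
    for (int ky = 0; ky < kernel_h; ky++)
        for (int kx = 0; kx < kernel_w; kx++)
        {
            T v = padded[((y0 + ky) * padded_width + (x0 + kx)) * channels + c];
            m = std::max(m, v);
        }
    return m;
}
} // namespace sketch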

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(TransposeConv);
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpTransposeConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->out_pad().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->output_shape().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
        return 1;
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
            return 1;
        }
    }

    for (int d = 0; d < 4; d++)
    {
        if (attribute->output_shape()[d] != this->output->getShape()[d])
        {
            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
            return 1;
        }
    }

    int32_t IH = input->getShape()[1];
    int32_t IW = input->getShape()[2];
    int32_t OH = output->getShape()[1];
    int32_t OW = output->getShape()[2];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_h = weight->getShape()[1];
    int32_t kernel_w = weight->getShape()[2];

    int32_t out_pad_top    = attribute->out_pad()[0];
    int32_t out_pad_bottom = attribute->out_pad()[1];
    int32_t out_pad_left   = attribute->out_pad()[2];
    int32_t out_pad_right  = attribute->out_pad()[3];

    for (size_t i = 0; i < attribute->out_pad().size(); i++)
    {
        ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]),
                 "OpTransposeConv2d: At least one out_pad value is less than or equal to the negated kernel dimension");
    }

    int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
    int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;

    if ((OH != H) || (OW != W))
    {
        std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
                          std::to_string(H) + "," +
                          std::to_string(W) + ")";
        printNodeValidationError(msg.c_str());
        return 1;
    }

    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");

    return 0;
}
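
// Worked example of the size relation above (illustrative numbers, not part
// of the operator): with IH = 4, stride_y = 2, kernel_h = 3 and zero out_pad,
// H = (4 - 1) * 2 + 0 + 0 + 3 = 9, so output_shape must declare a height of 9.
static_assert((4 - 1) * 2 + 0 + 0 + 3 == 9, "transpose_conv2d output height example");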

template <DType InDtype, DType WeightDtype, DType AccDtype>
int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    int out_pad_top    = this->attribute->out_pad()[0];
    int out_pad_bottom = this->attribute->out_pad()[1];
    int out_pad_left   = this->attribute->out_pad()[2];
    int out_pad_right  = this->attribute->out_pad()[3];

    int stride_h = this->attribute->stride()[0];
    int stride_w = this->attribute->stride()[1];

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
             f_out_channels, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    DEBUG_INFO(OP,
               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
               out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
               out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    int out_x_origin, out_y_origin;
    int out_x, out_y;

    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int ih = 0; ih < in_height; ih++)
        {
            for (int iw = 0; iw < in_width; iw++)
            {
                out_x_origin = iw * stride_w + out_pad_left;
                out_y_origin = ih * stride_h + out_pad_top;
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int fh = 0; fh < f_height; fh++)
                    {
                        for (int fw = 0; fw < f_width; fw++)
                        {
                            out_x = out_x_origin + fw;
                            out_y = out_y_origin + fh;
                            for (int oc = 0; oc < out_channels; oc++)
                            {
                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                                {
                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
                                        (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
                                                       (AccEigenType)weight_val(oc, fh, fw, ic));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}
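
// The loops above scatter each input element into the output. An equivalent
// gather view (a sketch under the same indexing, with a hypothetical helper
// name; not used by the reference model) asks, for a given output position
// and filter tap, which input position contributes:
namespace sketch {
// out = in * stride + out_pad + f   =>   in = (out - out_pad - f) / stride
inline bool contributing_input(int out_pos, int out_pad, int stride, int f_pos,
                               int in_extent, int* in_pos)
{
    int t = out_pos - out_pad - f_pos;
    if (t % stride != 0)   // only exact multiples of the stride land on an input
        return false;
    *in_pos = t / stride;
    return *in_pos >= 0 && *in_pos < in_extent;
}
} // namespace sketch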

// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);

DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);

// [in_t, weight_t, acc_t]
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);

DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);

DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);

DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);

DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT);

DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);

DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);